|
|
|
@ -12,7 +12,9 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
import ast
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
from typing import List
|
|
|
|
from typing import List
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Union
|
|
|
|
from typing import Union
|
|
|
|
@ -77,7 +79,7 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
prog='paddlespeech.cls', add_help=True)
|
|
|
|
prog='paddlespeech.cls', add_help=True)
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--input', type=str, required=True, help='Audio file to classify.')
|
|
|
|
'--input', type=str, default=None, help='Audio file to classify.')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--model',
|
|
|
|
'--model',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
@ -109,6 +111,11 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default=paddle.get_device(),
|
|
|
|
default=paddle.get_device(),
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--job_dump_result',
|
|
|
|
|
|
|
|
type=ast.literal_eval,
|
|
|
|
|
|
|
|
default=False,
|
|
|
|
|
|
|
|
help='Save job result into file.')
|
|
|
|
|
|
|
|
|
|
|
|
def _get_pretrained_path(self, tag: str) -> os.PathLike:
|
|
|
|
def _get_pretrained_path(self, tag: str) -> os.PathLike:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
@ -214,7 +221,7 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
ret = ''
|
|
|
|
ret = ''
|
|
|
|
for idx in topk_idx:
|
|
|
|
for idx in topk_idx:
|
|
|
|
label, score = self._label_list[idx], result[idx]
|
|
|
|
label, score = self._label_list[idx], result[idx]
|
|
|
|
ret += f'{label}: {score}\n'
|
|
|
|
ret += f'{label} {score} '
|
|
|
|
return ret
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
|
|
|
|
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
|
|
|
|
@ -234,18 +241,30 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
label_file = parser_args.label_file
|
|
|
|
label_file = parser_args.label_file
|
|
|
|
cfg_path = parser_args.config
|
|
|
|
cfg_path = parser_args.config
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
audio_file = parser_args.input
|
|
|
|
|
|
|
|
topk = parser_args.topk
|
|
|
|
topk = parser_args.topk
|
|
|
|
device = parser_args.device
|
|
|
|
device = parser_args.device
|
|
|
|
|
|
|
|
job_dump_result = parser_args.job_dump_result
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
task_source = self.get_task_source(parser_args.input)
|
|
|
|
res = self(audio_file, model_type, cfg_path, ckpt_path, label_file,
|
|
|
|
task_results = OrderedDict()
|
|
|
|
topk, device)
|
|
|
|
has_exceptions = False
|
|
|
|
logger.info('CLS Result:\n{}'.format(res))
|
|
|
|
|
|
|
|
return True
|
|
|
|
for id_, input_ in task_source.items():
|
|
|
|
except Exception as e:
|
|
|
|
try:
|
|
|
|
logger.exception(e)
|
|
|
|
res = self(input_, model_type, cfg_path, ckpt_path, label_file,
|
|
|
|
|
|
|
|
topk, device)
|
|
|
|
|
|
|
|
task_results[id_] = res
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
has_exceptions = True
|
|
|
|
|
|
|
|
task_results[id_] = f'{e.__class__.__name__}: {e}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.process_task_results(parser_args.input, task_results,
|
|
|
|
|
|
|
|
job_dump_result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if has_exceptions:
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
@stats_wrapper
|
|
|
|
@stats_wrapper
|
|
|
|
def __call__(self,
|
|
|
|
def __call__(self,
|
|
|
|
@ -259,7 +278,7 @@ class CLSExecutor(BaseExecutor):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Python API to call an executor.
|
|
|
|
Python API to call an executor.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
audio_file = os.path.abspath(os.path.expanduser(audio_file))
|
|
|
|
paddle.set_device(device)
|
|
|
|
paddle.set_device(device)
|
|
|
|
self._init_from_path(model, config, ckpt_path, label_file)
|
|
|
|
self._init_from_path(model, config, ckpt_path, label_file)
|
|
|
|
self.preprocess(audio_file)
|
|
|
|
self.preprocess(audio_file)
|
|
|
|
|