You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/deepspeech/models/st_interface.py

76 lines
2.5 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
from .asr_interface import ASRInterface
from deepspeech.utils.dynamic_import import dynamic_import
class STInterface(ASRInterface):
"""ST Interface model implementation.
NOTE: This class is inherited from ASRInterface to enable joint translation
and recognition when performing multi-task learning with the ASR task.
"""
def translate(self,
x,
trans_args,
char_list=None,
rnnlm=None,
ensemble_models=[]):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace trans_args: argment namespace contraining options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("translate method is not implemented")
def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace trans_args: argument namespace containing options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
predefined_st = {
"transformer": "deepspeech.models.u2_st:U2STModel",
}
def dynamic_import_st(module):
"""Import ST models dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_st`
Returns:
type: ST class
"""
model_class = dynamic_import(module, predefined_st)
assert issubclass(model_class,
STInterface), f"{module} does not implement STInterface"
return model_class