diff --git a/paddlespeech/s2t/decoders/__init__.py b/paddlespeech/s2t/decoders/__init__.py index f04a6d19..8878a76f 100644 --- a/paddlespeech/s2t/decoders/__init__.py +++ b/paddlespeech/s2t/decoders/__init__.py @@ -11,4 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .ctcdecoder import swig_wrapper + +from paddlespeech.s2t.utils.log import Log +logger = Log(__name__).getlog() + +try: + from .ctcdecoder import swig_wrapper +except: + try: + import pip + if int(pip.__version__.split('.')[0])>9: + from pip._internal import main + else: + from pip import main + package_name = 'paddlespeech_ctcdecoders' + main(['install', package_name]) + except Exception as e: + logger.info("paddlespeech_ctcdecoders not installed!") + diff --git a/paddlespeech/s2t/models/ds2/__init__.py b/paddlespeech/s2t/models/ds2/__init__.py index 39bea5bf..094fe928 100644 --- a/paddlespeech/s2t/models/ds2/__init__.py +++ b/paddlespeech/s2t/models/ds2/__init__.py @@ -14,4 +14,19 @@ from .deepspeech2 import DeepSpeech2InferModel from .deepspeech2 import DeepSpeech2Model +try: + import swig_decoders +except: + try: + import pip + if int(pip.__version__.split('.')[0])>9: + from pip._internal import main + else: + from pip import main + package_name = 'paddlespeech_ctcdecoders' + main(['install', package_name]) + except: + raise RuntimeError("Can not install package paddlespeech_ctcdecoders on your system. \ + The DeepSpeech2 model is not supported for your system") + __all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] diff --git a/paddlespeech/s2t/models/ds2_online/__init__.py b/paddlespeech/s2t/models/ds2_online/__init__.py index 255000ee..c8f2f55c 100644 --- a/paddlespeech/s2t/models/ds2_online/__init__.py +++ b/paddlespeech/s2t/models/ds2_online/__init__.py @@ -14,4 +14,20 @@ from .deepspeech2 import DeepSpeech2InferModelOnline from .deepspeech2 import DeepSpeech2ModelOnline +try: + import swig_decoders +except: + try: + import pip + if int(pip.__version__.split('.')[0])>9: + from pip._internal import main + else: + from pip import main + package_name = 'paddlespeech_ctcdecoders' + main(['install', package_name]) + except: + raise RuntimeError("Can not install package paddlespeech_ctcdecoders on your system. \ + The DeepSpeech2 model is not supported for your system") + + __all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline'] diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 120abd2b..2d368dae 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -28,8 +28,21 @@ try: from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401 -except Exception as e: - logger.info("ctcdecoder not installed!") +except: + try: + import pip + if int(pip.__version__.split('.')[0])>9: + from pip._internal import main + else: + from pip import main + package_name = 'paddlespeech_ctcdecoders' + main(['install', package_name]) + except Exception as e: + logger.info("paddlespeech_ctcdecoders not installed!") + +#try: +#except Exception as e: +# logger.info("ctcdecoder not installed!") __all__ = ['CTCDecoder'] @@ -51,7 +64,7 @@ class CTCDecoderBase(nn.Layer): dropout_rate (float): dropout rate (0.0 ~ 1.0) reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' batch_average (bool): do batch dim wise average. - grad_norm_type (str): Default, None. one of 'instance', 'batch', 'frame', None. + grad_norm_type (str): Default, None. one of 'instance', 'batch', 'frame', None. """ assert check_argument_types() super().__init__() diff --git a/setup.py b/setup.py index 7720ba3f..0816cc78 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,6 @@ requirements = { "nltk", "pandas", "paddleaudio", - "paddlespeech_ctcdecoders", "paddlespeech_feat", "praatio~=4.1", "pypi-kenlm", @@ -70,6 +69,7 @@ requirements = { "ConfigArgParse", "coverage", "gpustat", + "paddlespeech_ctcdecoders", "phkit", "Pillow", "pybind11",