Add Deepspeech2 online and offline in cli

pull/1356/head
huangyuxin 3 years ago
parent 26524031d2
commit 38edfd1a89

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import io
import os import os
import sys import sys
from typing import List from typing import List
@ -23,9 +22,9 @@ import librosa
import numpy as np import numpy as np
import paddle import paddle
import soundfile import soundfile
import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from ..download import get_path_from_url
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import cli_register from ..utils import cli_register
@ -64,14 +63,47 @@ pretrained_models = {
'ckpt_path': 'ckpt_path':
'exp/transformer/checkpoints/avg_10', 'exp/transformer/checkpoints/avg_10',
}, },
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'd5e076217cf60486519f72c217d21b9b',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
} }
model_alias = { model_alias = {
"deepspeech2offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model", "deepspeech2offline":
"deepspeech2online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
"conformer": "paddlespeech.s2t.models.u2:U2Model", "deepspeech2online":
"transformer": "paddlespeech.s2t.models.u2:U2Model", "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"wenetspeech": "paddlespeech.s2t.models.u2:U2Model", "conformer":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":
"paddlespeech.s2t.models.u2:U2Model",
} }
@ -95,7 +127,8 @@ class ASRExecutor(BaseExecutor):
'--lang', '--lang',
type=str, type=str,
default='zh', default='zh',
help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]') help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]'
)
self.parser.add_argument( self.parser.add_argument(
"--sample_rate", "--sample_rate",
type=int, type=int,
@ -111,7 +144,10 @@ class ASRExecutor(BaseExecutor):
'--decode_method', '--decode_method',
type=str, type=str,
default='attention_rescoring', default='attention_rescoring',
choices=['ctc_greedy_search', 'ctc_prefix_beam_search', 'attention', 'attention_rescoring'], choices=[
'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention',
'attention_rescoring'
],
help='only support transformer and conformer model') help='only support transformer and conformer model')
self.parser.add_argument( self.parser.add_argument(
'--ckpt_path', '--ckpt_path',
@ -187,13 +223,21 @@ class ASRExecutor(BaseExecutor):
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
self.vocab = self.config.vocab_filepath self.vocab = self.config.vocab_filepath
self.config.decode.lang_model_path = os.path.join(res_path, self.config.decode.lang_model_path) self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
self.collate_fn_test = SpeechCollator.from_config(self.config) self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, unit_type=self.config.unit_type, vocab=self.vocab)
vocab=self.vocab) lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5']
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
self.config.spm_model_prefix = os.path.join(self.res_path, self.config.spm_model_prefix) self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
@ -319,6 +363,13 @@ class ASRExecutor(BaseExecutor):
""" """
return self._outputs["result"] return self._outputs["result"]
def download_lm(self, url, lm_dir, md5sum):
download_path = get_path_from_url(
url=url,
root_dir=lm_dir,
md5sum=md5sum,
decompress=False, )
def _pcm16to32(self, audio): def _pcm16to32(self, audio):
assert (audio.dtype == np.int16) assert (audio.dtype == np.int16)
audio = audio.astype("float32") audio = audio.astype("float32")
@ -435,7 +486,8 @@ class ASRExecutor(BaseExecutor):
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
self._check(audio_file, sample_rate, force_yes) self._check(audio_file, sample_rate, force_yes)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, decode_method, ckpt_path) self._init_from_path(model, lang, sample_rate, config, decode_method,
ckpt_path)
self.preprocess(model, audio_file) self.preprocess(model, audio_file)
self.infer(model) self.infer(model)
res = self.postprocess() # Retrieve result of asr. res = self.postprocess() # Retrieve result of asr.

Loading…
Cancel
Save