From 8432e8626fed3d07b2e52b9ca22c77577665d040 Mon Sep 17 00:00:00 2001 From: jiamingkong Date: Wed, 31 May 2023 15:07:50 +0800 Subject: [PATCH] Final cleaning; Modified SSL/infer.py and README for wavlm inclusion in model options --- demos/speech_ssl/README.md | 2 +- demos/speech_ssl/README_cn.md | 2 +- examples/librispeech/asr5/run.sh | 2 +- paddlespeech/cli/ssl/infer.py | 8 +++++++- paddlespeech/s2t/exps/wavlm/bin/test_wav.py | 4 ++-- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md index 937cd95a..ef9b2237 100644 --- a/demos/speech_ssl/README.md +++ b/demos/speech_ssl/README.md @@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav ``` Arguments: - `input`(required): Audio file to recognize. - - `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert]. + - `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert, wavlm]. - `task`: Output type. Default: `asr`. - `lang`: Model language. Default: `en`. - `sample_rate`: Sample rate of the model. Default: `16000`. diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md index 8455d2c7..a18c778a 100644 --- a/demos/speech_ssl/README_cn.md +++ b/demos/speech_ssl/README_cn.md @@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav ``` 参数: - `input`(必须输入):用于识别的音频文件。 - - `model`:ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert]。 + - `model`:ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert, wavlm]。 - `task`:输出类别,默认值:`asr`。 - `lang`:模型语言,默认值:`en`。 - `sample_rate`:音频采样率,默认值:`16000`。 diff --git a/examples/librispeech/asr5/run.sh b/examples/librispeech/asr5/run.sh index 877891fc..9634bc8c 100644 --- a/examples/librispeech/asr5/run.sh +++ b/examples/librispeech/asr5/run.sh @@ -4,7 +4,7 @@ set -e . ./path.sh || exit 1; . ./cmd.sh || exit 1; -gpus=1,2,3 +gpus=0,1,2 stage=0 stop_stage=3 conf_path=conf/wavlmASR.yaml diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py index bc3c632d..9b4b0280 100644 --- a/paddlespeech/cli/ssl/infer.py +++ b/paddlespeech/cli/ssl/infer.py @@ -52,7 +52,7 @@ class SSLExecutor(BaseExecutor): '--model', type=str, default='wav2vec2', - choices=['wav2vec2', 'hubert'], + choices=['wav2vec2', 'hubert', "wavlm"], help='Choose model type of asr task.') self.parser.add_argument( '--task', @@ -157,6 +157,12 @@ class SSLExecutor(BaseExecutor): elif lang == 'zh': logger.error("zh hubertASR is not supported yet") tag = model_prefix + '-' + lang + '-' + sample_rate_str + elif model_type == 'wavlm': + if lang == "en": + model_prefix = "wavlmASR_librispeech" + elif lang == "zh": + logger.error("zh wavlmASR is not supported yet") + tag = model_prefix + '-' + lang + '-' + sample_rate_str else: tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(tag, version=None) diff --git a/paddlespeech/s2t/exps/wavlm/bin/test_wav.py b/paddlespeech/s2t/exps/wavlm/bin/test_wav.py index 468cca3d..e6c07629 100644 --- a/paddlespeech/s2t/exps/wavlm/bin/test_wav.py +++ b/paddlespeech/s2t/exps/wavlm/bin/test_wav.py @@ -29,7 +29,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig logger = Log(__name__).getlog() -class Wav2vec2Infer(): +class WavLMInfer(): def __init__(self, config, args): self.args = args self.config = config @@ -99,7 +99,7 @@ def check(audio_file): def main(config, args): - Wav2vec2Infer(config, args).run() + WavLMInfer(config, args).run() if __name__ == "__main__":