Final cleaning; Modified SSL/infer.py and README for wavlm inclusion in model options

pull/3242/head
jiamingkong 1 year ago
parent ba874db5dc
commit 8432e8626f

@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
Arguments:
- `input`(required): Audio file to recognize.
- `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert].
- `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert, wavlm].
- `task`: Output type. Default: `asr`.
- `lang`: Model language. Default: `en`.
- `sample_rate`: Sample rate of the model. Default: `16000`.

@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
参数:
- `input`(必须输入):用于识别的音频文件。
- `model`ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert]。
- `model`ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert, wavlm]。
- `task`:输出类别,默认值:`asr`。
- `lang`:模型语言,默认值:`en`。
- `sample_rate`:音频采样率,默认值:`16000`。

@ -4,7 +4,7 @@ set -e
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;
gpus=1,2,3
gpus=0,1,2
stage=0
stop_stage=3
conf_path=conf/wavlmASR.yaml

@ -52,7 +52,7 @@ class SSLExecutor(BaseExecutor):
'--model',
type=str,
default='wav2vec2',
choices=['wav2vec2', 'hubert'],
choices=['wav2vec2', 'hubert', "wavlm"],
help='Choose model type of asr task.')
self.parser.add_argument(
'--task',
@ -157,6 +157,12 @@ class SSLExecutor(BaseExecutor):
elif lang == 'zh':
logger.error("zh hubertASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
elif model_type == 'wavlm':
if lang == "en":
model_prefix = "wavlmASR_librispeech"
elif lang == "zh":
logger.error("zh wavlmASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
else:
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None)

@ -29,7 +29,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class Wav2vec2Infer():
class WavLMInfer():
def __init__(self, config, args):
self.args = args
self.config = config
@ -99,7 +99,7 @@ def check(audio_file):
def main(config, args):
Wav2vec2Infer(config, args).run()
WavLMInfer(config, args).run()
if __name__ == "__main__":

Loading…
Cancel
Save