Final cleaning; Modified SSL/infer.py and README for wavlm inclusion in model options

pull/3242/head
jiamingkong 1 year ago
parent ba874db5dc
commit 8432e8626f

@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
``` ```
Arguments: Arguments:
- `input`(required): Audio file to recognize. - `input`(required): Audio file to recognize.
- `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert]. - `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert, wavlm].
- `task`: Output type. Default: `asr`. - `task`: Output type. Default: `asr`.
- `lang`: Model language. Default: `en`. - `lang`: Model language. Default: `en`.
- `sample_rate`: Sample rate of the model. Default: `16000`. - `sample_rate`: Sample rate of the model. Default: `16000`.

@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
``` ```
参数: 参数:
- `input`(必须输入):用于识别的音频文件。 - `input`(必须输入):用于识别的音频文件。
- `model`ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert]。 - `model`ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert, wavlm]。
- `task`:输出类别,默认值:`asr`。 - `task`:输出类别,默认值:`asr`。
- `lang`:模型语言,默认值:`en`。 - `lang`:模型语言,默认值:`en`。
- `sample_rate`:音频采样率,默认值:`16000`。 - `sample_rate`:音频采样率,默认值:`16000`。

@ -4,7 +4,7 @@ set -e
. ./path.sh || exit 1; . ./path.sh || exit 1;
. ./cmd.sh || exit 1; . ./cmd.sh || exit 1;
gpus=1,2,3 gpus=0,1,2
stage=0 stage=0
stop_stage=3 stop_stage=3
conf_path=conf/wavlmASR.yaml conf_path=conf/wavlmASR.yaml

@ -52,7 +52,7 @@ class SSLExecutor(BaseExecutor):
'--model', '--model',
type=str, type=str,
default='wav2vec2', default='wav2vec2',
choices=['wav2vec2', 'hubert'], choices=['wav2vec2', 'hubert', "wavlm"],
help='Choose model type of asr task.') help='Choose model type of asr task.')
self.parser.add_argument( self.parser.add_argument(
'--task', '--task',
@ -157,6 +157,12 @@ class SSLExecutor(BaseExecutor):
elif lang == 'zh': elif lang == 'zh':
logger.error("zh hubertASR is not supported yet") logger.error("zh hubertASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str tag = model_prefix + '-' + lang + '-' + sample_rate_str
elif model_type == 'wavlm':
if lang == "en":
model_prefix = "wavlmASR_librispeech"
elif lang == "zh":
logger.error("zh wavlmASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
else: else:
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None) self.task_resource.set_task_model(tag, version=None)

@ -29,7 +29,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class Wav2vec2Infer(): class WavLMInfer():
def __init__(self, config, args): def __init__(self, config, args):
self.args = args self.args = args
self.config = config self.config = config
@ -99,7 +99,7 @@ def check(audio_file):
def main(config, args): def main(config, args):
Wav2vec2Infer(config, args).run() WavLMInfer(config, args).run()
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save