@ -150,10 +150,10 @@ class SSLExecutor(BaseExecutor):
model_prefix = ' wav2vec2ASR_librispeech '
model_prefix = ' wav2vec2ASR_librispeech '
elif lang == ' zh ' :
elif lang == ' zh ' :
model_prefix = ' wav2vec2ASR_aishell1 '
model_prefix = ' wav2vec2ASR_aishell1 '
tag = model_prefix + ' - ' + lang + ' - ' + sample_rate_str
tag = model_prefix + ' - ' + lang + ' - ' + sample_rate_str
elif model_type == ' hubert ' :
elif model_type == ' hubert ' :
if lang == ' en ' :
if lang == ' en ' :
model_prefix = ' hubertASR_librispeech _100 '
model_prefix = ' hubertASR_librispeech -100h '
elif lang == ' zh ' :
elif lang == ' zh ' :
logger . error ( " zh hubertASR is not supported yet " )
logger . error ( " zh hubertASR is not supported yet " )
tag = model_prefix + ' - ' + lang + ' - ' + sample_rate_str
tag = model_prefix + ' - ' + lang + ' - ' + sample_rate_str
@ -185,16 +185,17 @@ class SSLExecutor(BaseExecutor):
self . text_feature = TextFeaturizer (
self . text_feature = TextFeaturizer (
unit_type = self . config . unit_type ,
unit_type = self . config . unit_type ,
vocab = self . config . vocab_filepath )
vocab = self . config . vocab_filepath )
self . config . output_dim = len ( self . config . vocab_filepath )
elif lang == ' zh ' :
elif lang == ' zh ' :
self . text_feature = AutoTokenizer . from_pretrained (
self . text_feature = AutoTokenizer . from_pretrained (
self . config . tokenizer )
self . config . tokenizer )
self . config . output_dim = self . text_feature . vocab_size
self . config . decode . decoding_method = decode_method
self . config . decode . decoding_method = decode_method
model_name = model_ ty pe[ : model_ ty pe. rindex (
model_name = model_ pr efix [ : model_ pr efix . rindex (
' _ ' ) ] # model_type: {model_name}_{dataset}
' _ ' ) ] # model_type: {model_name}_{dataset}
else :
else :
model_name = model_type
model_name = model_type
model_class = self . task_resource . get_model_class ( model_name )
model_class = self . task_resource . get_model_class ( model_name )
model_conf = self . config
model_conf = self . config
model = model_class . from_config ( model_conf )
model = model_class . from_config ( model_conf )
self . model = model
self . model = model
@ -264,8 +265,7 @@ class SSLExecutor(BaseExecutor):
audio = self . _inputs [ " audio " ]
audio = self . _inputs [ " audio " ]
if task == ' asr ' :
if task == ' asr ' :
cfg = self . config . decode
cfg = self . config . decode
logger . debug (
logger . debug ( f " we will use the { model_type } ASR like model. " )
f " we will use the { model_type } ASR like model. " )
try :
try :
result_transcripts = self . model . decode (
result_transcripts = self . model . decode (
audio ,
audio ,
@ -278,7 +278,8 @@ class SSLExecutor(BaseExecutor):
logger . exception ( e )
logger . exception ( e )
else :
else :
logger . debug (
logger . debug (
f " we will use the { model_type } like model to extract audio feature. " )
f " we will use the { model_type } like model to extract audio feature. "
)
try :
try :
out_feature = self . model ( audio [ : , : , 0 ] )
out_feature = self . model ( audio [ : , : , 0 ] )
self . _outputs [ " result " ] = out_feature [ 0 ]
self . _outputs [ " result " ] = out_feature [ 0 ]
@ -455,7 +456,7 @@ class SSLExecutor(BaseExecutor):
if rtf :
if rtf :
k = self . __class__ . __name__
k = self . __class__ . __name__
CLI_TIMER [ k ] [ ' start ' ] . append ( time . time ( ) )
CLI_TIMER [ k ] [ ' start ' ] . append ( time . time ( ) )
self . preprocess ( model, audio_file)
self . preprocess ( audio_file)
self . infer ( model , task )
self . infer ( model , task )
res = self . postprocess ( ) # Retrieve result of asr.
res = self . postprocess ( ) # Retrieve result of asr.