@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
""" Contains U2 model. """
""" Contains U2 model. """
import sys
import time
import time
import logging
import logging
import numpy as np
import numpy as np
@ -256,11 +257,19 @@ class U2Tester(U2Trainer):
cutoff_prob = 1.0 , # Cutoff probability for pruning.
cutoff_prob = 1.0 , # Cutoff probability for pruning.
cutoff_top_n = 40 , # Cutoff number for pruning.
cutoff_top_n = 40 , # Cutoff number for pruning.
lang_model_path = ' models/lm/common_crawl_00.prune01111.trie.klm ' , # Filepath for language model.
lang_model_path = ' models/lm/common_crawl_00.prune01111.trie.klm ' , # Filepath for language model.
decoding_method = ' ctc_beam_search ' , # Decoding method. Options: ctc_beam_search, ctc_greedy
decoding_method = ' attention ' , # Decoding method. Options: 'attention', 'ctc_greedy_search',
# 'ctc_prefix_beam_search', 'attention_rescoring'
error_rate_type = ' wer ' , # Error rate type for evaluation. Options `wer`, 'cer'
error_rate_type = ' wer ' , # Error rate type for evaluation. Options `wer`, 'cer'
num_proc_bsearch = 8 , # # of CPUs for beam search.
num_proc_bsearch = 8 , # # of CPUs for beam search.
beam_size = 500 , # Beam search width.
beam_size = 10 , # Beam search width.
batch_size = 128 , # decoding batch size
batch_size = 16 , # decoding batch size
ctc_weight = 0.0 , # ctc weight for attention rescoring decode mode.
decoding_chunk_size = - 1 , # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks = - 1 , # number of left chunks for decoding. Defaults to -1.
simulate_streaming = False , # simulate streaming inference. Defaults to False.
) )
) )
if config is not None :
if config is not None :
@ -279,19 +288,19 @@ class U2Tester(U2Trainer):
trans . append ( ' ' . join ( [ chr ( i ) for i in ids ] ) )
trans . append ( ' ' . join ( [ chr ( i ) for i in ids ] ) )
return trans
return trans
def compute_metrics ( self , audio , texts, audio_len, texts _len) :
def compute_metrics ( self , audio , audio_len, texts , texts _len, fout = None ) :
cfg = self . config . decoding
cfg = self . config . decoding
errors_sum , len_refs , num_ins = 0.0 , 0 , 0
errors_sum , len_refs , num_ins = 0.0 , 0 , 0
errors_func = error_rate . char_errors if cfg . error_rate_type == ' cer ' else error_rate . word_errors
errors_func = error_rate . char_errors if cfg . error_rate_type == ' cer ' else error_rate . word_errors
error_rate_func = error_rate . cer if cfg . error_rate_type == ' cer ' else error_rate . wer
error_rate_func = error_rate . cer if cfg . error_rate_type == ' cer ' else error_rate . wer
vocab_list = self . test_loader . dataset . vocab_list
text_feature = self . test_loader . dataset . text_feature
target_transcripts = self . ordid2token ( texts , texts_len )
target_transcripts = self . ordid2token ( texts , texts_len )
result_transcripts = self . model . decode (
result_transcripts = self . model . decode (
audio ,
audio ,
audio_len ,
audio_len ,
vocab_list ,
text_feature= text_feature ,
decoding_method = cfg . decoding_method ,
decoding_method = cfg . decoding_method ,
lang_model_path = cfg . lang_model_path ,
lang_model_path = cfg . lang_model_path ,
beam_alpha = cfg . alpha ,
beam_alpha = cfg . alpha ,
@ -299,13 +308,19 @@ class U2Tester(U2Trainer):
beam_size = cfg . beam_size ,
beam_size = cfg . beam_size ,
cutoff_prob = cfg . cutoff_prob ,
cutoff_prob = cfg . cutoff_prob ,
cutoff_top_n = cfg . cutoff_top_n ,
cutoff_top_n = cfg . cutoff_top_n ,
num_processes = cfg . num_proc_bsearch )
num_processes = cfg . num_proc_bsearch ,
ctc_weight = cfg . ctc_weight ,
decoding_chunk_size = cfg . decoding_chunk_size ,
num_decoding_left_chunks = cfg . num_decoding_left_chunks ,
simulate_streaming = cfg . simulate_streaming )
for target , result in zip ( target_transcripts , result_transcripts ) :
for target , result in zip ( target_transcripts , result_transcripts ) :
errors , len_ref = errors_func ( target , result )
errors , len_ref = errors_func ( target , result )
errors_sum + = errors
errors_sum + = errors
len_refs + = len_ref
len_refs + = len_ref
num_ins + = 1
num_ins + = 1
if fout :
fout . write ( result + " \n " )
self . logger . info (
self . logger . info (
" \n Target Transcription: %s \n Output Transcription: %s " %
" \n Target Transcription: %s \n Output Transcription: %s " %
( target , result ) )
( target , result ) )
@ -322,6 +337,7 @@ class U2Tester(U2Trainer):
@mp_tools.rank_zero_only
@mp_tools.rank_zero_only
@paddle.no_grad ( )
@paddle.no_grad ( )
def test ( self ) :
def test ( self ) :
assert self . args . result_file
self . model . eval ( )
self . model . eval ( )
self . logger . info (
self . logger . info (
f " Test Total Examples: { len ( self . test_loader . dataset ) } " )
f " Test Total Examples: { len ( self . test_loader . dataset ) } " )
@ -329,14 +345,16 @@ class U2Tester(U2Trainer):
error_rate_type = None
error_rate_type = None
errors_sum , len_refs , num_ins = 0.0 , 0 , 0
errors_sum , len_refs , num_ins = 0.0 , 0 , 0
for i , batch in enumerate ( self . test_loader ) :
with open ( self . args . result_file , ' w ' ) as fout :
metrics = self . compute_metrics ( * batch )
for i , batch in enumerate ( self . test_loader ) :
errors_sum + = metrics [ ' errors_sum ' ]
metrics = self . compute_metrics ( * batch , fout = fout )
len_refs + = metrics [ ' len_refs ' ]
errors_sum + = metrics [ ' errors_sum ' ]
num_ins + = metrics [ ' num_ins ' ]
len_refs + = metrics [ ' len_refs ' ]
error_rate_type = metrics [ ' error_rate_type ' ]
num_ins + = metrics [ ' num_ins ' ]
self . logger . info ( " Error rate [ %s ] ( %d /?) = %f " %
error_rate_type = metrics [ ' error_rate_type ' ]
( error_rate_type , num_ins , errors_sum / len_refs ) )
self . logger . info (
" Error rate [ %s ] ( %d /?) = %f " %
( error_rate_type , num_ins , errors_sum / len_refs ) )
# logging
# logging
msg = " Test: "
msg = " Test: "
@ -351,24 +369,34 @@ class U2Tester(U2Trainer):
try :
try :
self . test ( )
self . test ( )
except KeyboardInterrupt :
except KeyboardInterrupt :
exit( - 1 )
sys. exit( - 1 )
def load_inferspec(self):
    """Build the inference model and the input spec used for export.

    The model is created with ``U2InferModel.from_pretrained`` from the test
    dataset, a clone of the model config and ``self.args.checkpoint_path``.

    Returns:
        nn.Layer: inference model.
        List[paddle.static.InputSpec]: input spec.
    """
    # Imported at call time, as in the original code.
    from deepspeech.models.u2 import U2InferModel

    infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
                                               self.config.model.clone(),
                                               self.args.checkpoint_path)
    feat_dim = self.test_loader.dataset.feature_size
    input_spec = [
        paddle.static.InputSpec(
            shape=[None, feat_dim, None],
            dtype='float32'),  # audio, [B,D,T]
        paddle.static.InputSpec(shape=[None],
                                dtype='int64'),  # audio_length, [B]
    ]
    return infer_model, input_spec


def export(self):
    """Convert the inference model to a static graph and save it.

    The model and its input spec come from ``load_inferspec``; the static
    model is written to ``self.args.export_path``.
    """
    infer_model, input_spec = self.load_inferspec()
    # Sanity check on the helper's contract: the spec must be a list.
    assert isinstance(input_spec, list), type(input_spec)
    # Switch to eval mode before converting the model.
    infer_model.eval()
    static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
    # NOTE(review): this uses a module-level ``logger`` while the other
    # methods in view log through ``self.logger`` - confirm it is defined.
    logger.info(f"Export code: {static_model.forward.code}")
    paddle.jit.save(static_model, self.args.export_path)
@ -376,7 +404,7 @@ class U2Tester(U2Trainer):
try :
try :
self . export ( )
self . export ( )
except KeyboardInterrupt :
except KeyboardInterrupt :
exit( - 1 )
sys. exit( - 1 )
def setup ( self ) :
def setup ( self ) :
""" Setup the experiment.
""" Setup the experiment.