@@ -16,6 +16,7 @@
 import json
 import os
 import time
 from collections import defaultdict
+from collections import OrderedDict
 from contextlib import nullcontext
 from typing import Optional
@@ -23,21 +24,18 @@
 import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
-from paddle.io import DataLoader
 from yacs.config import CfgNode
-from paddlespeech.s2t.io.collator import SpeechCollator
-from paddlespeech.s2t.io.collator import TripletSpeechCollator
-from paddlespeech.s2t.io.dataset import ManifestDataset
-from paddlespeech.s2t.io.sampler import SortagradBatchSampler
-from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
+from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
+from paddlespeech.s2t.io.dataloader import BatchDataLoader
 from paddlespeech.s2t.models.u2_st import U2STModel
-from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
-from paddlespeech.s2t.training.scheduler import WarmupLR
+from paddlespeech.s2t.training.optimizer import OptimizerFactory
+from paddlespeech.s2t.training.reporter import ObsScope
+from paddlespeech.s2t.training.reporter import report
+from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
 from paddlespeech.s2t.training.timer import Timer
 from paddlespeech.s2t.training.trainer import Trainer
 from paddlespeech.s2t.utils import bleu_score
-from paddlespeech.s2t.utils import ctc_utils
 from paddlespeech.s2t.utils import layer_tools
 from paddlespeech.s2t.utils import mp_tools
 from paddlespeech.s2t.utils.log import Log
@@ -96,6 +94,8 @@ class U2STTrainer(Trainer):
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
+        if st_loss:
+            losses_np['st_loss'] = float(st_loss)
         if attention_loss:
             losses_np['att_loss'] = float(attention_loss)
         if ctc_loss:
@@ -125,6 +125,12 @@ class U2STTrainer(Trainer):
         iteration_time = time.time() - start

+        for k, v in losses_np.items():
+            report(k, v)
+        report("batch_size", self.config.collator.batch_size)
+        report("accum", train_conf.accum_grad)
+        report("step_cost", iteration_time)
+
         if (batch_index + 1) % train_conf.log_interval == 0:
             msg += "train time: {:>.3f}s, ".format(iteration_time)
             msg += "batch size: {}, ".format(self.config.batch_size)
@@ -204,16 +210,34 @@ class U2STTrainer(Trainer):
                     data_start_time = time.time()
                     for batch_index, batch in enumerate(self.train_loader):
                         dataload_time = time.time() - data_start_time
-                        msg = "Train: Rank: {}, ".format(dist.get_rank())
-                        msg += "epoch: {}, ".format(self.epoch)
-                        msg += "step: {}, ".format(self.iteration)
-                        msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                        len(self.train_loader))
-                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                        msg += "data time: {:>.3f}s, ".format(dataload_time)
-                        self.train_batch(batch_index, batch, msg)
-                        self.after_train_batch()
-                        data_start_time = time.time()
+                        msg = "Train:"
+                        observation = OrderedDict()
+                        with ObsScope(observation):
+                            report("Rank", dist.get_rank())
+                            report("epoch", self.epoch)
+                            report('step', self.iteration)
+                            report("lr", self.lr_scheduler())
+                            self.train_batch(batch_index, batch, msg)
+                            self.after_train_batch()
+                            report('iter', batch_index + 1)
+                            report('total', len(self.train_loader))
+                            report('reader_cost', dataload_time)
+                            observation['batch_cost'] = observation[
+                                'reader_cost'] + observation['step_cost']
+                            observation['samples'] = observation['batch_size']
+                            observation['ips,sent./sec'] = observation[
+                                'batch_size'] / observation['batch_cost']
+                        for k, v in observation.items():
+                            msg += f" {k.split(',')[0]}: "
+                            msg += f"{v:>.8f}" if isinstance(v,
+                                                             float) else f"{v}"
+                            msg += f" {k.split(',')[1]}" if len(
+                                k.split(',')) == 2 else ""
+                            msg += ","
+                        msg = msg[:-1]  # remove the last ","
+                        if (batch_index + 1
+                            ) % self.config.training.log_interval == 0:
+                            logger.info(msg)
                 except Exception as e:
                     logger.error(e)
                     raise e
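Note on the rewritten loop above: `ObsScope(observation)` makes the `OrderedDict` the active sink for every `report(key, value)` call issued inside the block (including the `batch_size` and `step_cost` values reported from `train_batch`), and the dict is then formatted into the log line. Below is a minimal standalone sketch of that pattern using only the standard library; it is an assumption-level illustration, not the actual `paddlespeech.s2t.training.reporter` implementation.

    from collections import OrderedDict
    from contextlib import contextmanager

    _OBSERVATION = None  # module-level sink, stands in for the reporter's global scope

    @contextmanager
    def obs_scope(observation):
        """Route report() calls into `observation` while the block is active."""
        global _OBSERVATION
        _OBSERVATION = observation
        try:
            yield observation
        finally:
            _OBSERVATION = None

    def report(key, value):
        """Record a metric into the currently active observation dict."""
        if _OBSERVATION is not None:
            _OBSERVATION[key] = value

    observation = OrderedDict()
    with obs_scope(observation):
        report('reader_cost', 0.012)  # would come from dataloader timing
        report('step_cost', 0.340)    # would come from train_batch()
    observation['batch_cost'] = observation['reader_cost'] + observation['step_cost']
    print(observation)  # reader_cost, step_cost and the derived batch_cost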
@@ -244,97 +268,88 @@ class U2STTrainer(Trainer):
     def setup_dataloader(self):
         config = self.config.clone()
-        config.defrost()
-        config.keep_transcription_text = False
-
-        # train/valid dataset, return token ids
-        config.manifest = config.train_manifest
-        train_dataset = ManifestDataset.from_config(config)
-
-        config.manifest = config.dev_manifest
-        dev_dataset = ManifestDataset.from_config(config)
-
-        if config.model_conf.asr_weight > 0.:
-            Collator = TripletSpeechCollator
-            TestCollator = SpeechCollator
-        else:
-            TestCollator = Collator = SpeechCollator
-
-        collate_fn_train = Collator.from_config(config)
-
-        config.augmentation_config = ""
-        collate_fn_dev = Collator.from_config(config)
-
-        if self.parallel:
-            batch_sampler = SortagradDistributedBatchSampler(
-                train_dataset,
-                batch_size=config.batch_size,
-                num_replicas=None,
-                rank=None,
-                shuffle=True,
-                drop_last=True,
-                sortagrad=config.sortagrad,
-                shuffle_method=config.shuffle_method)
-        else:
-            batch_sampler = SortagradBatchSampler(
-                train_dataset,
-                shuffle=True,
-                batch_size=config.batch_size,
-                drop_last=True,
-                sortagrad=config.sortagrad,
-                shuffle_method=config.shuffle_method)
-        self.train_loader = DataLoader(
-            train_dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=collate_fn_train,
-            num_workers=config.num_workers, )
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=config.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_dev,
-            num_workers=config.num_workers, )
-
-        # test dataset, return raw text
-        config.manifest = config.test_manifest
-        # filter test examples, will cause less examples, but no mismatch with training
-        # and can use large batch size, save training time, so filter test egs now.
-        # config.min_input_len = 0.0  # second
-        # config.max_input_len = float('inf')  # second
-        # config.min_output_len = 0.0  # tokens
-        # config.max_output_len = float('inf')  # tokens
-        # config.min_output_input_ratio = 0.00
-        # config.max_output_input_ratio = float('inf')
-        test_dataset = ManifestDataset.from_config(config)
-        # return text ord id
-        config.keep_transcription_text = True
-        config.augmentation_config = ""
-        decode_batch_size = config.get('decode', dict()).get(
-            'decode_batch_size', 1)
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=decode_batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=TestCollator.from_config(config),
-            num_workers=config.num_workers, )
-        # return text token id
-        config.keep_transcription_text = False
-        self.align_loader = DataLoader(
-            test_dataset,
-            batch_size=decode_batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=TestCollator.from_config(config),
-            num_workers=config.num_workers, )
-        logger.info("Setup train/valid/test/align Dataloader!")
+        load_transcript = True if config.model_conf.asr_weight > 0 else False
+
+        if self.train:
+            # train/valid dataset, return token ids
+            self.train_loader = BatchDataLoader(
+                json_file=config.train_manifest,
+                train_mode=True,
+                sortagrad=False,
+                batch_size=config.batch_size,
+                maxlen_in=config.maxlen_in,
+                maxlen_out=config.maxlen_out,
+                minibatches=0,
+                mini_batch_size=1,
+                batch_count='auto',
+                batch_bins=0,
+                batch_frames_in=0,
+                batch_frames_out=0,
+                batch_frames_inout=0,
+                preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
+                n_iter_processes=config.num_workers,
+                subsampling_factor=1,
+                load_aux_output=load_transcript,
+                num_encs=1,
+                dist_sampler=True)
+            self.valid_loader = BatchDataLoader(
+                json_file=config.dev_manifest,
+                train_mode=False,
+                sortagrad=False,
+                batch_size=config.batch_size,
+                maxlen_in=float('inf'),
+                maxlen_out=float('inf'),
+                minibatches=0,
+                mini_batch_size=1,
+                batch_count='auto',
+                batch_bins=0,
+                batch_frames_in=0,
+                batch_frames_out=0,
+                batch_frames_inout=0,
+                preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
+                n_iter_processes=config.num_workers,
+                subsampling_factor=1,
+                load_aux_output=load_transcript,
+                num_encs=1,
+                dist_sampler=True)
+            logger.info("Setup train/valid Dataloader!")
+        else:
+            # test dataset, return raw text
+            decode_batch_size = config.get('decode', dict()).get('decode_batch_size', 1)
+            self.test_loader = BatchDataLoader(
+                json_file=config.data.test_manifest,
+                train_mode=False,
+                sortagrad=False,
+                batch_size=decode_batch_size,
+                maxlen_in=float('inf'),
+                maxlen_out=float('inf'),
+                minibatches=0,
+                mini_batch_size=1,
+                batch_count='auto',
+                batch_bins=0,
+                batch_frames_in=0,
+                batch_frames_out=0,
+                batch_frames_inout=0,
+                preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
+                n_iter_processes=config.num_workers,
+                subsampling_factor=1,
+                num_encs=1,
+                dist_sampler=False)
+            logger.info("Setup test Dataloader!")

     def setup_model(self):
         config = self.config
         model_conf = config
         with UpdateConfig(model_conf):
-            model_conf.input_dim = self.train_loader.collate_fn.feature_size
-            model_conf.output_dim = self.train_loader.collate_fn.vocab_size
+            if self.train:
+                model_conf.input_dim = self.train_loader.feat_dim
+                model_conf.output_dim = self.train_loader.vocab_size
+            else:
+                model_conf.input_dim = self.test_loader.feat_dim
+                model_conf.output_dim = self.test_loader.vocab_size

         model = U2STModel.from_config(model_conf)
@@ -350,35 +365,38 @@ class U2STTrainer(Trainer):
         scheduler_type = train_config.scheduler
         scheduler_conf = train_config.scheduler_conf
-        if scheduler_type == 'expdecaylr':
-            lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
-                learning_rate=optim_conf.lr,
-                gamma=scheduler_conf.lr_decay,
-                verbose=False)
-        elif scheduler_type == 'warmuplr':
-            lr_scheduler = WarmupLR(
-                learning_rate=optim_conf.lr,
-                warmup_steps=scheduler_conf.warmup_steps,
-                verbose=False)
-        elif scheduler_type == 'noam':
-            lr_scheduler = paddle.optimizer.lr.NoamDecay(
-                learning_rate=optim_conf.lr,
-                d_model=model_conf.encoder_conf.output_size,
-                warmup_steps=scheduler_conf.warmup_steps,
-                verbose=False)
-        else:
-            raise ValueError(f"Not support scheduler: {scheduler_type}")
-
-        grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
-        weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
-        if optim_type == 'adam':
-            optimizer = paddle.optimizer.Adam(
-                learning_rate=lr_scheduler,
-                parameters=model.parameters(),
-                weight_decay=weight_decay,
-                grad_clip=grad_clip)
-        else:
-            raise ValueError(f"Not support optim: {optim_type}")
+        scheduler_args = {
+            "learning_rate": optim_conf.lr,
+            "verbose": False,
+            "warmup_steps": scheduler_conf.warmup_steps,
+            "gamma": scheduler_conf.lr_decay,
+            "d_model": model_conf.encoder_conf.output_size,
+        }
+        lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
+                                                    scheduler_args)
+
+        def optimizer_args(
+                config,
+                parameters,
+                lr_scheduler=None, ):
+            train_config = config.training
+            optim_type = train_config.optim
+            optim_conf = train_config.optim_conf
+            scheduler_type = train_config.scheduler
+            scheduler_conf = train_config.scheduler_conf
+            return {
+                "grad_clip": train_config.global_grad_clip,
+                "weight_decay": optim_conf.weight_decay,
+                "learning_rate": lr_scheduler
+                if lr_scheduler else optim_conf.lr,
+                "parameters": parameters,
+                "epsilon": 1e-9 if optim_type == 'noam' else None,
+                "beta1": 0.9 if optim_type == 'noam' else None,
+                "beat2": 0.98 if optim_type == 'noam' else None,
+            }
+
+        optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
+        optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)

         self.model = model
         self.optimizer = optimizer
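Note: the scheduler/optimizer change above replaces hard-coded `if/elif` construction with argument dicts dispatched through `LRSchedulerFactory` / `OptimizerFactory`, where each target picks only the keys it needs from one superset dict. A rough sketch of that dispatch-by-name idea, assuming Paddle is installed; `lr_scheduler_from_args` is a hypothetical helper, not the real factory in `paddlespeech.s2t.training`.

    import paddle

    def lr_scheduler_from_args(scheduler_type, args):
        """Build a Paddle LR scheduler from a flat args dict, ignoring unused keys."""
        if scheduler_type == 'expdecaylr':
            return paddle.optimizer.lr.ExponentialDecay(
                learning_rate=args['learning_rate'],
                gamma=args['gamma'],
                verbose=args['verbose'])
        elif scheduler_type == 'noam':
            return paddle.optimizer.lr.NoamDecay(
                learning_rate=args['learning_rate'],
                d_model=args['d_model'],
                warmup_steps=args['warmup_steps'],
                verbose=args['verbose'])
        raise ValueError(f"Not support scheduler: {scheduler_type}")

    # One superset dict serves every scheduler type; unused keys are ignored.
    scheduler_args = {
        'learning_rate': 0.002,
        'verbose': False,
        'warmup_steps': 25000,
        'gamma': 0.9,
        'd_model': 256,
    }
    lr_scheduler = lr_scheduler_from_args('noam', scheduler_args)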
@@ -418,26 +436,30 @@ class U2STTester(U2STTrainer):
     def __init__(self, config, args):
         super().__init__(config, args)
+        self.text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
+        self.vocab_list = self.text_feature.vocab_list

-    def ordid2token(self, texts, texts_len):
+    def id2token(self, texts, texts_len, text_feature):
         """ord() id to chr() chr"""
         trans = []
         for text, n in zip(texts, texts_len):
             n = n.numpy().item()
             ids = text[:n]
-            trans.append(''.join([chr(i) for i in ids]))
+            trans.append(text_feature.defeaturize(ids.numpy().tolist()))
         return trans

     def translate(self, audio, audio_len):
         """E2E translation from extracted audio feature"""
         decode_cfg = self.config.decode
-        text_feature = self.test_loader.collate_fn.text_feature
         self.model.eval()
         hyps = self.model.decode(
             audio,
             audio_len,
-            text_feature=text_feature,
+            text_feature=self.text_feature,
             decoding_method=decode_cfg.decoding_method,
             beam_size=decode_cfg.beam_size,
             word_reward=decode_cfg.word_reward,
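Note: `id2token` above now detokenizes reference ids through the `TextFeaturizer` built in `__init__` instead of mapping ids back with `chr()`. A toy illustration of the call shape, with a hypothetical stand-in featurizer (the real `TextFeaturizer.defeaturize` resolves ids through the vocab file or SentencePiece model):

    import numpy as np

    class ToyTextFeaturizer:
        """Hypothetical stand-in for paddlespeech.s2t.frontend.featurizer.TextFeaturizer."""

        def __init__(self, vocab):
            self.vocab_list = vocab

        def defeaturize(self, ids):
            return ' '.join(self.vocab_list[i] for i in ids)

    def id2token(texts, texts_len, text_feature):
        """Strip padding using texts_len, then map each id sequence to a string."""
        trans = []
        for text, n in zip(texts, texts_len):
            ids = text[:int(n)]
            trans.append(text_feature.defeaturize(list(ids)))
        return trans

    feat = ToyTextFeaturizer(['<blank>', 'hello', 'world', '<eos>'])
    texts = np.array([[1, 2, 3, 0], [2, 3, 0, 0]])  # padded batches of token ids
    texts_len = np.array([3, 2])
    print(id2token(texts, texts_len, feat))  # ['hello world <eos>', 'world <eos>']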
@@ -458,23 +480,20 @@ class U2STTester(U2STTrainer):
         len_refs, num_ins = 0, 0
         start_time = time.time()
-        text_feature = self.test_loader.collate_fn.text_feature
-        refs = [
-            "".join(chr(t) for t in text[:text_len])
-            for text, text_len in zip(texts, texts_len)
-        ]
+        refs = self.id2token(texts, texts_len, self.text_feature)

         hyps = self.model.decode(
             audio,
             audio_len,
-            text_feature=text_feature,
+            text_feature=self.text_feature,
             decoding_method=decode_cfg.decoding_method,
             beam_size=decode_cfg.beam_size,
             word_reward=decode_cfg.word_reward,
             decoding_chunk_size=decode_cfg.decoding_chunk_size,
             num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
             simulate_streaming=decode_cfg.simulate_streaming)
         decode_time = time.time() - start_time

         for utt, target, result in zip(utts, refs, hyps):
@@ -507,7 +526,7 @@ class U2STTester(U2STTrainer):
         decode_cfg = self.config.decode
         bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
-        stride_ms = self.test_loader.collate_fn.stride_ms
+        stride_ms = self.config.collator.stride_ms
         hyps, refs = [], []
         len_refs, num_ins = 0, 0
         num_frames = 0.0
@@ -524,7 +543,8 @@ class U2STTester(U2STTrainer):
                 len_refs += metrics['len_refs']
                 num_ins += metrics['num_ins']
                 rtf = num_time / (num_frames * stride_ms)
-                logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu))
+                logger.info("RTF: %f, instance (%d), batch BELU = %f" %
+                            (rtf, num_ins, bleu))

         rtf = num_time / (num_frames * stride_ms)
         msg = "Test: "
@@ -555,13 +575,6 @@ class U2STTester(U2STTrainer):
                 })
                 f.write(data + '\n')

-    @paddle.no_grad()
-    def align(self):
-        ctc_utils.ctc_align(self.config, self.model, self.align_loader,
-                            self.config.decode.decode_batch_size,
-                            self.config.stride_ms, self.vocab_list,
-                            self.args.result_file)
-
     def load_inferspec(self):
         """infer model and input spec.
@@ -569,11 +582,11 @@ class U2STTester(U2STTrainer):
             nn.Layer: inference model
             List[paddle.static.InputSpec]: input spec.
         """
-        from paddlespeech.s2t.models.u2 import U2InferModel
-        infer_model = U2InferModel.from_pretrained(self.test_loader,
+        from paddlespeech.s2t.models.u2_st import U2STInferModel
+        infer_model = U2STInferModel.from_pretrained(self.test_loader,
                                                      self.config.clone(),
                                                      self.args.checkpoint_path)
-        feat_dim = self.test_loader.collate_fn.feature_size
+        feat_dim = self.test_loader.feat_dim
         input_spec = [
             paddle.static.InputSpec(shape=[1, None, feat_dim],
                                     dtype='float32'),  # audio, [B,T,D]
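Note: `load_inferspec` pairs the inference model with `paddle.static.InputSpec` entries whose `None` dimension marks the variable-length time axis. A rough sketch of how such a spec can drive dynamic-to-static export with `paddle.jit`, using a hypothetical `TinyNet` stand-in rather than `U2STInferModel`:

    import paddle
    from paddle.static import InputSpec

    class TinyNet(paddle.nn.Layer):
        """Hypothetical stand-in for the exported ST model."""

        def __init__(self, feat_dim=80, hidden=32):
            super().__init__()
            self.proj = paddle.nn.Linear(feat_dim, hidden)

        def forward(self, audio):
            # audio: [B, T, D] with a dynamic T dimension
            return self.proj(audio)

    net = TinyNet()
    net.eval()
    input_spec = [
        InputSpec(shape=[1, None, 80], dtype='float32'),  # audio, [B,T,D]
    ]
    # Convert the dynamic-graph Layer to static graph and save an inference model.
    static_net = paddle.jit.to_static(net, input_spec=input_spec)
    paddle.jit.save(static_net, 'exported/tiny_net')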