@@ -16,6 +16,7 @@ import json
import os
import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from typing import Optional
@@ -23,21 +24,18 @@ import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.collator import TripletSpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.models.u2_st import U2STModel
from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
from paddlespeech.s2t.training.scheduler import WarmupLR
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils import bleu_score
from paddlespeech.s2t.utils import ctc_utils
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.log import Log
@@ -96,6 +94,8 @@ class U2STTrainer(Trainer):
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
losses_np = {'loss': float(loss) * train_conf.accum_grad}
if st_loss:
losses_np['st_loss'] = float(st_loss)
if attention_loss:
losses_np['att_loss'] = float(attention_loss)
if ctc_loss:
@@ -125,6 +125,12 @@ class U2STTrainer(Trainer):
iteration_time = time.time() - start
for k, v in losses_np.items():
report(k, v)
report("batch_size", self.config.collator.batch_size)
report("accum", train_conf.accum_grad)
report("step_cost", iteration_time)
if (batch_index + 1) % train_conf.log_interval == 0:
msg += "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(self.config.batch_size)
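# Editorial note (not part of the diff): the "batch_size" and "step_cost"
# values reported above, together with the "reader_cost" reported in the
# training loop below, are the keys read back out of the observation dict
# to derive "batch_cost", "samples" and "ips,sent./sec".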
@@ -204,16 +210,34 @@ class U2STTrainer(Trainer):
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch: {}/{}, ".format(batch_index + 1, len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
data_start_time = time.time()
msg = "Train:"
observation = OrderedDict()
with ObsScope(observation):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation['reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips,sent./sec'] = observation['batch_size'] / observation['batch_cost']
for k, v in observation.items():
msg += f" {k.split(',')[0]}: "
msg += f"{v:>.8f}" if isinstance(v, float) else f"{v}"
msg += f" {k.split(',')[1]}" if len(k.split(',')) == 2 else ""
msg += ","
msg = msg[:-1]  # remove the last ","
if (batch_index + 1) % self.config.training.log_interval == 0:
logger.info(msg)
except Exception as e:
logger.error(e)
raise e
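# Editorial sketch (assumption, not the PaddleSpeech implementation): the
# ObsScope/report pair used above can be read as a context manager that points
# a module-level dict at the caller's OrderedDict, so that report(key, value)
# calls made anywhere inside the block land in that dict.
from collections import OrderedDict
from contextlib import contextmanager

_OBSERVATIONS = None


@contextmanager
def obs_scope(observation):
    """Route report() calls into `observation` for the duration of the block."""
    global _OBSERVATIONS
    prev, _OBSERVATIONS = _OBSERVATIONS, observation
    try:
        yield observation
    finally:
        _OBSERVATIONS = prev


def report(key, value):
    """Record a metric into the currently active observation dict, if any."""
    if _OBSERVATIONS is not None:
        _OBSERVATIONS[key] = value


# Usage mirroring the loop above: every report() inside the scope fills `obs`.
obs = OrderedDict()
with obs_scope(obs):
    report("lr", 2.5e-4)
    report("reader_cost", 0.031)
# obs == OrderedDict([('lr', 0.00025), ('reader_cost', 0.031)])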
@@ -244,97 +268,88 @@ class U2STTrainer(Trainer):
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.keep_transcription_text = False
# train/valid dataset, return token ids
config.manifest = config.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.manifest = config.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
if config.model_conf.asr_weight > 0.:
Collator = TripletSpeechCollator
TestCollator = SpeechCollator
else:
TestCollator = Collator = SpeechCollator
collate_fn_train = Collator.from_config(config)
config.augmentation_config = ""
collate_fn_dev = Collator.from_config(config)
load_transcript = True if config.model_conf.asr_weight > 0 else False
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
if self.train:
# train/valid dataset, return token ids
self.train_loader = BatchDataLoader(
json_file=config.train_manifest,
train_mode=True,
sortagrad=False,
batch_size=config.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.sortagrad,
shuffle_method=config.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
maxlen_in=config.maxlen_in,
maxlen_out=config.maxlen_out,
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1,
dist_sampler=True)
self.valid_loader = BatchDataLoader(
json_file=config.dev_manifest,
train_mode=False,
sortagrad=False,
batch_size=config.batch_size,
drop_last=True,
sortagrad=config.sortagrad,
shuffle_method=config.shuffle_method)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.num_workers, )
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev,
num_workers=config.num_workers, )
# test dataset, return raw text
config.manifest = config.test_manifest
# filter test examples; this yields fewer examples but no mismatch with training,
# and a large batch size can be used to save training time, so filter test egs now.
# config.min_input_len = 0.0  # second
# config.max_input_len = float('inf')  # second
# config.min_output_len = 0.0  # tokens
# config.max_output_len = float('inf')  # tokens
# config.min_output_input_ratio = 0.00
# config.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.keep_transcription_text = True
config.augmentation_config = ""
decode_batch_size = config.get('decode', dict()).get('decode_batch_size', 1)
self.test_loader = DataLoader(
test_dataset,
batch_size=decode_batch_size,
shuffle=False,
drop_last=False,
collate_fn=TestCollator.from_config(config),
num_workers=config.num_workers, )
# return text token id
config.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=decode_batch_size,
shuffle=False,
drop_last=False,
collate_fn=TestCollator.from_config(config),
num_workers=config.num_workers, )
logger.info("Setup train/valid/test/align Dataloader!")
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1,
dist_sampler=True)
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
decode_batch_size = config.get('decode', dict()).get('decode_batch_size', 1)
self.test_loader = BatchDataLoader(
json_file=config.data.test_manifest,
train_mode=False,
sortagrad=False,
batch_size=decode_batch_size,
maxlen_in=float('inf'),
maxlen_out=float('inf'),
minibatches=0,
mini_batch_size=1,
batch_count='auto',
batch_bins=0,
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
dist_sampler=False)
logger.info("Setup test Dataloader!")
def setup_model(self):
config = self.config
model_conf = config
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
if self.train:
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
else:
model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size
model = U2STModel.from_config(model_conf)
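# Editorial note (assumption): UpdateConfig presumably defrosts the yacs
# CfgNode for the duration of the block, so input_dim/output_dim can be
# written from the active loader (train or test) before U2STModel is built.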
@@ -350,35 +365,38 @@ class U2STTrainer(Trainer):
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
if scheduler_type == 'expdecaylr':
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=optim_conf.lr,
gamma=scheduler_conf.lr_decay,
verbose=False)
elif scheduler_type == 'warmuplr':
lr_scheduler = WarmupLR(
learning_rate=optim_conf.lr,
warmup_steps=scheduler_conf.warmup_steps,
verbose=False)
elif scheduler_type == 'noam':
lr_scheduler = paddle.optimizer.lr.NoamDecay(
learning_rate=optim_conf.lr,
d_model=model_conf.encoder_conf.output_size,
warmup_steps=scheduler_conf.warmup_steps,
verbose=False)
else:
raise ValueError(f"Not support scheduler: {scheduler_type}")
grad_clip = ClipGradByGlobalNormWithLog(train_config.global_grad_clip)
weight_decay = paddle.regularizer.L2Decay(optim_conf.weight_decay)
if optim_type == 'adam':
optimizer = paddle.optimizer.Adam(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=weight_decay,
grad_clip=grad_clip)
else:
raise ValueError(f"Not support optim: {optim_type}")
scheduler_args = {
"learning_rate": optim_conf.lr,
"verbose": False,
"warmup_steps": scheduler_conf.warmup_steps,
"gamma": scheduler_conf.lr_decay,
"d_model": model_conf.encoder_conf.output_size,
}
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, scheduler_args)
def optimizer_args(config, parameters, lr_scheduler=None):
train_config = config.training
optim_type = train_config.optim
optim_conf = train_config.optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"weight_decay": optim_conf.weight_decay,
"learning_rate": lr_scheduler if lr_scheduler else optim_conf.lr,
"parameters": parameters,
"epsilon": 1e-9 if optim_type == 'noam' else None,
"beta1": 0.9 if optim_type == 'noam' else None,
"beta2": 0.98 if optim_type == 'noam' else None,
}
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
self.model = model
self.optimizer = optimizer
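# Editorial sketch (assumption, not the actual LRSchedulerFactory /
# OptimizerFactory code): the factory calls above can be read as a simple
# name -> class registry that forwards only the keyword arguments a given
# class accepts, which is why scheduler_args can carry warmup_steps, gamma
# and d_model for every scheduler type at once.
import inspect

import paddle

_SCHEDULERS = {
    "expdecaylr": paddle.optimizer.lr.ExponentialDecay,
    "noam": paddle.optimizer.lr.NoamDecay,
}


def scheduler_from_args(name, args):
    """Build the named scheduler, dropping kwargs it does not accept."""
    cls = _SCHEDULERS[name]
    accepted = inspect.signature(cls).parameters
    kwargs = {k: v for k, v in args.items() if k in accepted and v is not None}
    return cls(**kwargs)


# e.g. scheduler_from_args("noam", scheduler_args) would forward only
# learning_rate, warmup_steps, d_model and verbose to NoamDecay.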
@@ -418,26 +436,30 @@ class U2STTester(U2STTrainer):
def __init__(self, config, args):
super().__init__(config, args)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
def ordid2token(self, texts, texts_len):
def id2token(self, texts, texts_len, text_feature):
"""Convert padded token id tensors back to text via the text featurizer."""
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(''.join([chr(i) for i in ids]))
trans.append(text_feature.defeaturize(ids.numpy().tolist()))
return trans
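# Editorial usage sketch (hypothetical ids and lengths, called from inside the
# tester, e.g. in compute_metrics): id2token inverts the featurizer, turning
# padded id tensors back into strings, assuming
# TextFeaturizer.defeaturize(list_of_ids) -> str as used above.
texts = paddle.to_tensor([[12, 7, 45, 0], [3, 99, 0, 0]])  # padded token ids
texts_len = paddle.to_tensor([3, 2])  # true lengths per utterance
refs = self.id2token(texts, texts_len, self.text_feature)
# refs is a list of decoded strings, one per utterance, padding stripped.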
def translate(self, audio, audio_len):
"""E2E translation from extracted audio feature"""
decode_cfg = self.config.decode
text_feature = self.test_loader.collate_fn.text_feature
self.model.eval()
hyps = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
text_feature=self.text_feature,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size,
word_reward=decode_cfg.word_reward,
@@ -458,23 +480,20 @@ class U2STTester(U2STTrainer):
len_refs, num_ins = 0, 0
start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature
refs = [
"".join(chr(t) for t in text[:text_len])
for text, text_len in zip(texts, texts_len)
]
refs = self.id2token(texts, texts_len, self.text_feature)
hyps = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
text_feature=self.text_feature,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size,
word_reward=decode_cfg.word_reward,
decoding_chunk_size=decode_cfg.decoding_chunk_size,
num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
simulate_streaming=decode_cfg.simulate_streaming)
decode_time = time.time() - start_time
for utt, target, result in zip(utts, refs, hyps):
@@ -507,7 +526,7 @@ class U2STTester(U2STTrainer):
decode_cfg = self.config.decode
bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
stride_ms = self.test_loader.collate_fn.stride_ms
stride_ms = self.config.collator.stride_ms
hyps, refs = [], []
len_refs, num_ins = 0, 0
num_frames = 0.0
@@ -524,7 +543,8 @@ class U2STTester(U2STTrainer):
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
rtf = num_time / (num_frames * stride_ms)
logger.info("RTF: %f, BLEU (%d) = %f" % (rtf, num_ins, bleu))
logger.info("RTF: %f, instance (%d), batch BLEU = %f" % (rtf, num_ins, bleu))
rtf = num_time / (num_frames * stride_ms)
msg = "Test: "
@@ -555,13 +575,6 @@ class U2STTester(U2STTrainer):
})
f.write(data + '\n')
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decode.decode_batch_size,
self.config.stride_ms, self.vocab_list,
self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.
@@ -569,11 +582,11 @@
nn.Layer: inference model
List[paddle.static.InputSpec]: input spec.
"""
from paddlespeech.s2t.models.u2 import U2InferModel
infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.clone(), self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size
from paddlespeech.s2t.models.u2_st import U2STInferModel
infer_model = U2STInferModel.from_pretrained(self.test_loader, self.config.clone(), self.args.checkpoint_path)
feat_dim = self.test_loader.feat_dim
input_spec = [
paddle.static.InputSpec(shape=[1, None, feat_dim], dtype='float32'),  # audio, [B,T,D]