add result output

pull/657/head
Haoxin Ma 4 years ago
parent f3c9f32c9a
commit a58b1cb30a

@ -193,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, audio, audio_len, texts, texts_len):
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -215,11 +215,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
for target, result in zip(target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write(utt + " " + result + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("Current error rate [%s] = %f" %
@ -240,16 +242,16 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
for i, batch in enumerate(self.test_loader):
utt, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(audio, audio_len, texts, texts_len)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "

@ -76,8 +76,9 @@ class U2Trainer(Trainer):
def train_batch(self, batch_index, batch_data, msg):
train_conf = self.config.training
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(*batch_data)
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
@ -119,9 +120,10 @@ class U2Trainer(Trainer):
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
loss, attention_loss, ctc_loss = self.model(*batch)
utt, audio, audio_len, text, text_len = batch
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
if paddle.isfinite(loss):
num_utts = batch[0].shape[0]
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
@ -366,7 +368,7 @@ class U2Tester(U2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, audio, audio_len, texts, texts_len, fout=None):
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None, fref=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -393,13 +395,15 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for target, result in zip(target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write(result + "\n")
fout.write(utt + " " + result + "\n")
if fref:
fref.write(utt + " " + target + "\n")
logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))
logger.info("One example error rate [%s] = %f" %
@ -428,6 +432,7 @@ class U2Tester(U2Trainer):
num_time = 0.0
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
# utt, audio, audio_len, text, text_len = batch
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]

@ -51,7 +51,10 @@ class SpeechCollator():
audio_lens = []
texts = []
text_lens = []
utts = []
for utt, audio, text in batch:
#utt
utts.append(utt)
# audio
audios.append(audio.T) # [T, D]
audio_lens.append(audio.shape[1])
@ -75,4 +78,4 @@ class SpeechCollator():
padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return utt, padded_audios, audio_lens, padded_texts, text_lens
return utts, padded_audios, audio_lens, padded_texts, text_lens

@ -905,6 +905,7 @@ class U2InferModel(U2Model):
def __init__(self, configs: dict):
super().__init__(configs)
def forward(self,
feats,
feats_lengths,

@ -114,7 +114,8 @@ class ConvBn(nn.Layer):
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply
masks = masks.type_as(x)
# masks = masks.type_as(x)
masks = masks.astype(x)
x = x.multiply(masks)
return x, x_len

@ -26,7 +26,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then

@ -8,7 +8,7 @@ data:
spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
batch_size: 2 #4
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
@ -31,7 +31,7 @@ data:
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
num_workers: 0 #2
# network architecture
@ -70,7 +70,7 @@ model:
training:
n_epoch: 20
n_epoch: 2
accum_grad: 1
global_grad_clip: 5.0
optim: adam
@ -85,7 +85,7 @@ training:
decoding:
batch_size: 64
batch_size: 8 #64
error_rate_type: wer
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm

@ -20,20 +20,22 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
./local/avg.sh exp/${ckpt}/checkpoints ${avg_num}
avg.sh exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
# CUDA_VISIBLE_DEVICES=7
./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
# CUDA_VISIBLE_DEVICES=
./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi

Loading…
Cancel
Save