diff --git a/tools/tune.py b/tools/tune.py index 966029a8..47abf141 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -88,7 +88,8 @@ def tune(): augmentation_config='{}', specgram_type=args.specgram_type, num_threads=args.num_proc_data, - keep_transcription_text=True) + keep_transcription_text=True, + num_conv_layers=args.num_conv_layers) audio_data = paddle.layer.data( name="audio_spectrogram", @@ -96,10 +97,25 @@ def tune(): text_data = paddle.layer.data( name="transcript_text", type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) + seq_offset_data = paddle.layer.data( + name='sequence_offset', + type=paddle.data_type.integer_value_sequence(1)) + seq_len_data = paddle.layer.data( + name='sequence_length', + type=paddle.data_type.integer_value_sequence(1)) + index_range_datas = [] + for i in xrange(args.num_rnn_layers): + index_range_datas.append( + paddle.layer.data( + name='conv%d_index_range' % i, + type=paddle.data_type.dense_vector(6))) output_probs, _ = deep_speech_v2_network( audio_data=audio_data, text_data=text_data, + seq_offset_data=seq_offset_data, + seq_len_data=seq_len_data, + index_range_datas=index_range_datas, dict_size=data_generator.vocab_size, num_conv_layers=args.num_conv_layers, num_rnn_layers=args.num_rnn_layers, @@ -156,15 +172,17 @@ def tune(): for infer_data in batch_reader(): if (args.num_batches >= 0) and (cur_batch >= args.num_batches): break - infer_results = inferer.infer(input=infer_data) - - num_steps = len(infer_results) // len(infer_data) + infer_results = inferer.infer(input=infer_data, + feeding=data_generator.feeding) + start_pos = [0] * (len(infer_data) + 1) + for i in xrange(len(infer_data)): + start_pos[i + 1] = start_pos[i] + infer_data[i][3][0] probs_split = [ - infer_results[i * num_steps:(i + 1) * num_steps] - for i in xrange(len(infer_data)) + infer_results[start_pos[i]:start_pos[i + 1]] + for i in xrange(0, len(infer_data)) ] - target_transcripts = [transcript for _, transcript in infer_data] + target_transcripts = [ data[1] for data in infer_data ] num_ins += len(target_transcripts) # grid search