|
|
|
@ -34,7 +34,6 @@ add_arg('beta_from', float, 0.1, "Where beta starts tuning from.")
|
|
|
|
|
add_arg('beta_to', float, 0.45, "Where beta ends tuning with.")
|
|
|
|
|
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
|
|
|
|
|
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
|
|
|
|
|
add_arg('output_fig', bool, True, "Output error rate figure or not.")
|
|
|
|
|
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
|
|
|
|
|
add_arg('use_gpu', bool, True, "Use GPU or not.")
|
|
|
|
|
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
|
|
|
|
@ -66,26 +65,9 @@ add_arg('specgram_type', str,
|
|
|
|
|
# yapf: disable
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
def plot_error_surface(params_grid, err_ave, fig_name):
    """Plot the averaged error rate over the (alpha, beta) grid as a 3-D surface.

    The grid shape and the z-axis label are read from the module-level
    ``args`` (``num_alphas``, ``num_betas`` and ``error_rate_type``).

    :param params_grid: Sequence of (alpha, beta) pairs, one per grid point.
                        Its length must equal num_alphas * num_betas so it
                        can be reshaped to (num_alphas, num_betas).
    :param err_ave: Averaged error rate for every grid point, in the same
                    order as ``params_grid``.
    :param fig_name: Path or file name handed to ``plt.savefig``.
    """
    # Imported lazily so matplotlib is only needed when a figure is requested.
    import matplotlib.pyplot as plt
    # Importing Axes3D registers the '3d' projection on older matplotlib
    # releases; the name itself is intentionally unused.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    grid_shape = (args.num_alphas, args.num_betas)
    alphas = [param[0] for param in params_grid]
    betas = [param[1] for param in params_grid]
    ALPHAS = np.reshape(alphas, grid_shape)
    BETAS = np.reshape(betas, grid_shape)
    ERR_AVE = np.reshape(err_ave, grid_shape)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # The surface height is the averaged error rate (previously referenced an
    # undefined name ``WERS``).
    ax.plot_surface(ALPHAS, BETAS, ERR_AVE,
                    rstride=1, cstride=1, alpha=0.8, cmap='rainbow')
    ax.set_xlabel('alpha')
    ax.set_ylabel('beta')
    z_label = 'WER' if args.error_rate_type == 'wer' else 'CER'
    ax.set_zlabel(z_label)
    plt.savefig(fig_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
def tune():
|
|
|
|
|
"""Tune parameters alpha and beta on one minibatch."""
|
|
|
|
|
"""Tune parameters alpha and beta incrementally."""
|
|
|
|
|
if not args.num_alphas >= 0:
|
|
|
|
|
raise ValueError("num_alphas must be non-negative!")
|
|
|
|
|
if not args.num_betas >= 0:
|
|
|
|
@ -160,38 +142,36 @@ def tune():
|
|
|
|
|
err_ave[index] = err_sum[index] / num_ins
|
|
|
|
|
# print("alpha = %f, beta = %f, WER = %f" %
|
|
|
|
|
# (alpha, beta, err_ave[index]))
|
|
|
|
|
if index % 10 == 0:
|
|
|
|
|
if index % 2 == 0:
|
|
|
|
|
sys.stdout.write('.')
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
|
|
|
|
|
# output on-line tuning result at the end of current batch
|
|
|
|
|
err_ave_min = min(err_ave)
|
|
|
|
|
min_index = err_ave.index(err_ave_min)
|
|
|
|
|
print("\nBatch %d, opt.(alpha, beta) = (%f, %f), min. error_rate = %f"
|
|
|
|
|
%(cur_batch, params_grid[min_index][0],
|
|
|
|
|
params_grid[min_index][1], err_ave_min))
|
|
|
|
|
print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
|
|
|
|
|
" min [%s] = %f" %(cur_batch, num_ins,
|
|
|
|
|
"%.3f" % params_grid[min_index][0],
|
|
|
|
|
"%.3f" % params_grid[min_index][1],
|
|
|
|
|
args.error_rate_type, err_ave_min))
|
|
|
|
|
cur_batch += 1
|
|
|
|
|
|
|
|
|
|
# output WER/CER at every point
|
|
|
|
|
print("\nerror rate at each point:\n")
|
|
|
|
|
print("\nFinal %s:\n" % args.error_rate_type)
|
|
|
|
|
for index in xrange(len(params_grid)):
|
|
|
|
|
print("(%f, %f), error_rate = %f"
|
|
|
|
|
% (params_grid[index][0], params_grid[index][1], err_ave[index]))
|
|
|
|
|
print("(alpha, beta) = (%s, %s), [%s] = %f"
|
|
|
|
|
% ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1],
|
|
|
|
|
args.error_rate_type, err_ave[index]))
|
|
|
|
|
|
|
|
|
|
err_ave_min = min(err_ave)
|
|
|
|
|
min_index = err_ave.index(err_ave_min)
|
|
|
|
|
print("\nTuning on %d batches, opt. (alpha, beta) = (%f, %f)"
|
|
|
|
|
% (args.num_batches, params_grid[min_index][0],
|
|
|
|
|
params_grid[min_index][1]))
|
|
|
|
|
|
|
|
|
|
if args.output_fig == True:
|
|
|
|
|
fig_name = ("error_surface_alphas_%d_betas_%d" %
|
|
|
|
|
(args.num_alphas, args.num_betas))
|
|
|
|
|
plot_error_surface(params_grid, err_ave, fig_name)
|
|
|
|
|
ds2_model.logger.info("output figure %s" % fig_name)
|
|
|
|
|
print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
|
|
|
|
|
% (args.num_batches, "%.3f" % params_grid[min_index][0],
|
|
|
|
|
"%.3f" % params_grid[min_index][1]))
|
|
|
|
|
|
|
|
|
|
ds2_model.logger.info("finish inference")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
print_arguments(args)
|
|
|
|
|
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
|
|
|
|
|