add trainer and utils; add setup model and dataloader; update travis using Bionic dist (pull/522/head)
parent 508182752e
commit c2ccb11ba0
@@ -1,555 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import collections

import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

__all__ = ['DeepSpeech2']


def brelu(x, t_min=0.0, t_max=24.0, name=None):
    # bounded relu: clip activations into [t_min, t_max]
    t_min = paddle.to_tensor(t_min)
    t_max = paddle.to_tensor(t_max)
    return x.maximum(t_min).minimum(t_max)


def sequence_mask(x_len, max_len=None, dtype='float32'):
    # build a [B, max_len] mask with ones at valid positions, zeros at padding
    max_len = max_len or x_len.max()
    x_len = paddle.unsqueeze(x_len, -1)
    row_vector = paddle.arange(max_len)
    mask = row_vector < x_len
    mask = paddle.cast(mask, dtype)
    return mask


class ConvBn(nn.Layer):
    """Convolution layer with batch normalization.

    :param kernel_size: The x dimension of a filter kernel. Or input a tuple for
                        two image dimensions.
    :type kernel_size: int|tuple|list
    :param num_channels_in: Number of input channels.
    :type num_channels_in: int
    :param num_channels_out: Number of output channels.
    :type num_channels_out: int
    :param stride: The x dimension of the stride. Or input a tuple for two
                   image dimensions.
    :type stride: int|tuple|list
    :param padding: The x dimension of the padding. Or input a tuple for two
                    image dimensions.
    :type padding: int|tuple|list
    :param act: Activation type, relu|brelu
    :type act: string
    :param masks: Masks data layer to reset padding.
    :type masks: Variable
    :param name: Name of the layer.
    :type name: string
    :return: Batch norm layer after convolution layer.
    :rtype: Variable
    """

    def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
                 padding, act):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.conv = nn.Conv2D(
            num_channels_in,
            num_channels_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            weight_attr=None,
            bias_attr=None,
            data_format='NCHW')

        self.bn = nn.BatchNorm2D(
            num_channels_out,
            weight_attr=None,
            bias_attr=None,
            data_format='NCHW')
        self.act = paddle.relu if act == 'relu' else brelu

    def forward(self, x, x_len):
        """
        x(Tensor): audio, shape [B, C, D, T]
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)

        # update the valid length along the time axis after striding
        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
                 ) // self.stride[1] + 1

        # reset padding part to 0
        masks = sequence_mask(x_len)  # [B, T]
        masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
        x = x.multiply(masks)

        return x, x_len


class ConvStack(nn.Layer):
    """Convolution group with stacked convolution layers.

    :param feat_size: audio feature dim.
    :type feat_size: int
    :param num_stacks: Number of stacked convolution layers.
    :type num_stacks: int
    """

    def __init__(self, feat_size, num_stacks):
        super().__init__()
        self.feat_size = feat_size  # D
        self.num_stacks = num_stacks

        self.filter_size = (41, 11)  # [D, T]
        self.stride = (2, 3)
        self.padding = (20, 5)
        self.conv_in = ConvBn(
            num_channels_in=1,
            num_channels_out=32,
            kernel_size=self.filter_size,
            stride=self.stride,
            padding=self.padding,
            act='brelu', )

        out_channel = 32
        self.conv_stack = nn.LayerList([
            ConvBn(
                num_channels_in=32,
                num_channels_out=out_channel,
                kernel_size=(21, 11),
                stride=(2, 1),
                padding=(10, 5),
                act='brelu') for i in range(num_stacks - 1)
        ])

        # conv output feat_dim: each layer halves the frequency axis
        output_height = (feat_size - 1) // 2 + 1
        for i in range(self.num_stacks - 1):
            output_height = (output_height - 1) // 2 + 1
        self.output_height = out_channel * output_height

    def forward(self, x, x_len):
        """
        x: shape [B, C, D, T]
        x_len: shape [B]
        """
        x, x_len = self.conv_in(x, x_len)
        for i, conv in enumerate(self.conv_stack):
            x, x_len = conv(x, x_len)
        return x, x_len


class RNNCell(nn.RNNCellBase):
    r"""
    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
    computes the outputs and updates states.
    The formula used is as follows:
    .. math::
        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
        y_{t} & = h_{t}

    where :math:`act` is for :attr:`activation`.
    """

    def __init__(self,
                 hidden_size,
                 activation="tanh",
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_hh = self.create_parameter(
            (hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = self.create_parameter(
            (hidden_size, ),
            bias_ih_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))
        self.bias_hh = self.create_parameter(
            (hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        if activation not in ["tanh", "relu", "brelu"]:
            raise ValueError(
                "activation for SimpleRNNCell should be tanh, relu or brelu, "
                "but got {}".format(activation))
        self.activation = activation
        self._activation_fn = paddle.tanh \
            if activation == "tanh" \
            else F.relu
        if activation == 'brelu':
            self._activation_fn = brelu

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        pre_h = states
        # the input projection is assumed to be done outside the cell,
        # so only the biases and the hidden-hidden matmul happen here
        i2h = inputs
        if self.bias_ih is not None:
            i2h += self.bias_ih
        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h2h += self.bias_hh
        h = self._activation_fn(i2h + h2h)
        return h, h

    @property
    def state_shape(self):
        return (self.hidden_size, )


class GRUCellShare(nn.RNNCellBase):
    r"""
    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
    it computes the outputs and updates states.
    The formula for GRU used is as follows:
    .. math::
        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, and * is the elementwise
    multiplication operator.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_hh = self.create_parameter(
            (3 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = self.create_parameter(
            (3 * hidden_size, ),
            bias_ih_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))
        self.bias_hh = self.create_parameter(
            (3 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        self.input_size = input_size
        self._gate_activation = F.sigmoid
        self._activation = paddle.tanh

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)

        pre_hidden = states
        x_gates = inputs
        if self.bias_ih is not None:
            x_gates = x_gates + self.bias_ih
        h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h_gates = h_gates + self.bias_hh

        x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
        h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)

        r = self._gate_activation(x_r + h_r)
        z = self._gate_activation(x_z + h_z)
        c = self._activation(x_c + r * h_c)  # apply reset gate after mm
        h = (pre_hidden - c) * z + c  # i.e. z * h_{t-1} + (1 - z) * c

        return h, h

    @property
    def state_shape(self):
        r"""
        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
        size would be automatically inserted into shape). The shape corresponds
        to the shape of :math:`h_{t-1}`.
        """
        return (self.hidden_size, )


class BiRNNWithBN(nn.Layer):
    """Bidirectional simple rnn layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer parameters.
    :type name: string
    :param size: Dimension of RNN cells.
    :type size: int
    :param share_weights: Whether to share input-hidden weights between
                          forward and backward directional RNNs.
    :type share_weights: bool
    :return: Bidirectional simple rnn layer.
    :rtype: Variable
    """

    def __init__(self, i_size, h_size, share_weights):
        super().__init__()
        self.share_weights = share_weights
        self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32))
        if self.share_weights:
            # input-hidden weights shared between bi-directional rnn.
            self.fw_fc = nn.Linear(i_size, h_size)
            # batch norm is only performed on input-state projection
            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
            self.bw_fc = self.fw_fc
            self.bw_bn = self.fw_bn
        else:
            self.fw_fc = nn.Linear(i_size, h_size)
            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
            self.bw_fc = nn.Linear(i_size, h_size)
            self.bw_bn = nn.BatchNorm1D(h_size, data_format='NLC')

        self.fw_cell = RNNCell(hidden_size=h_size, activation='relu')
        self.bw_cell = RNNCell(
            hidden_size=h_size,
            activation='relu', )
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  # [B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  # [B, T, D]

    def forward(self, x, x_len):
        # x, shape [B, T, D]
        fw_x = self.fw_bn(self.fw_fc(x))
        bw_x = self.bw_bn(self.bw_fc(x))
        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
        x = paddle.concat([fw_x, bw_x], axis=-1)
        return x, x_len


class BiGRUWithBN(nn.Layer):
    """Bidirectional gru layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer.
    :type name: string
    :param input: Input layer.
    :type input: Variable
    :param size: Dimension of GRU cells.
    :type size: int
    :param act: Activation type.
    :type act: string
    :return: Bidirectional GRU layer.
    :rtype: Variable
    """

    def __init__(self, i_size, h_size, act):
        super().__init__()
        hidden_size = h_size * 3
        self.fw_fc = nn.Linear(i_size, hidden_size)
        self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
        self.bw_fc = nn.Linear(i_size, hidden_size)
        self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')

        self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
        self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  # [B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  # [B, T, D]

    def forward(self, x, x_len):
        # x, shape [B, T, D]
        fw_x = self.fw_bn(self.fw_fc(x))
        bw_x = self.bw_bn(self.bw_fc(x))
        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
        x = paddle.concat([fw_x, bw_x], axis=-1)
        return x, x_len


class RNNStack(nn.Layer):
    """RNN group with stacked bidirectional simple RNN or GRU layers.

    :param input: Input layer.
    :type input: Variable
    :param size: Dimension of RNN cells in each layer.
    :type size: int
    :param num_stacks: Number of stacked rnn layers.
    :type num_stacks: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward directional RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: Output layer of the RNN group.
    :rtype: Variable
    """

    def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):
        super().__init__()
        self.rnn_stacks = nn.LayerList()
        for i in range(num_stacks):
            if use_gru:
                # default: GRU using tanh
                self.rnn_stacks.append(
                    BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu"))
            else:
                self.rnn_stacks.append(
                    BiRNNWithBN(
                        i_size=i_size,
                        h_size=h_size,
                        share_weights=share_rnn_weights))
            i_size = h_size * 2

    def forward(self, x, x_len):
        """
        x: shape [B, T, D]
        x_len: shape [B]
        """
        for i, rnn in enumerate(self.rnn_stacks):
            x, x_len = rnn(x, x_len)
        return x, x_len


class DeepSpeech2(nn.Layer):
    """The DeepSpeech2 network structure.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: Variable
    :param text_data: Transcription text data layer.
    :type text_data: Variable
    :param audio_len: Valid sequence length data layer.
    :type audio_len: Variable
    :param masks: Masks data layer to reset padding.
    :type masks: Variable
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (dimension of RNN cells).
    :type rnn_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward direction RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=256,
                 use_gru=False,
                 share_rnn_weights=True):
        super().__init__()
        self.feat_size = feat_size  # 161 for linear
        self.dict_size = dict_size

        self.conv = ConvStack(feat_size, num_conv_layers)

        i_size = self.conv.output_height  # H after conv stack
        self.rnn = RNNStack(
            i_size=i_size,
            h_size=rnn_size,
            num_stacks=num_rnn_layers,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)
        self.fc = nn.Linear(rnn_size * 2, dict_size + 1)

    def predict(self, audio, audio_len):
        # [B, D, T] -> [B, C=1, D, T]
        audio = audio.unsqueeze(1)

        # convolution group
        x, audio_len = self.conv(audio, audio_len)

        # convert data from convolution feature map to sequence of vectors
        B, C, D, T = paddle.shape(x)
        x = x.transpose([0, 3, 1, 2])  # [B, T, C, D]
        x = x.reshape([B, T, C * D])  # [B, T, C*D]

        # remove padding part
        x, audio_len = self.rnn(x, audio_len)  # [B, T, D]

        logits = self.fc(x)  # [B, T, V + 1]

        # ctcdecoder needs probs, not log_probs
        probs = F.softmax(logits)

        return logits, probs

    @paddle.no_grad()
    def infer(self, audio, audio_len):
        _, probs = self.predict(audio, audio_len)
        return probs

    def forward(self, audio, text, audio_len, text_len):
        """
        audio: shape [B, D, T]
        text: shape [B, T]
        audio_len: shape [B]
        text_len: shape [B]
        """
        logits, probs = self.predict(audio, audio_len)
        return logits


class DeepSpeechLoss(nn.Layer):
    def __init__(self, vocab_size):
        super().__init__()
        self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')

    def forward(self, logits, text, audio_len, text_len):
        # warp-ctc does softmax on activations
        # warp-ctc needs activations with shape [T, B, V + 1]
        logits = logits.transpose([1, 0, 2])

        ctc_loss = self.loss(logits, text, audio_len, text_len)
        ctc_loss /= text_len  # norm_by_times
        ctc_loss = ctc_loss.sum()
        return ctc_loss
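
For orientation, a minimal smoke-test sketch of how the classes in this removed file fit together; the dimensions and values below are illustrative assumptions, not values from this commit:

    import paddle

    # hypothetical toy sizes: batch 4, 161 linear-spectrogram bins, 100 frames
    model = DeepSpeech2(feat_size=161, dict_size=28)

    audio = paddle.randn([4, 161, 100])                             # [B, D, T]
    audio_len = paddle.to_tensor([100, 96, 88, 80], dtype='int64')  # valid frames
    text = paddle.randint(0, 28, shape=[4, 12])                     # [B, U] token ids
    text_len = paddle.to_tensor([12, 10, 9, 8], dtype='int64')

    logits = model(audio, text, audio_len, text_len)  # [B, T', dict_size + 1]
    probs = model.infer(audio, audio_len)             # softmax probs for ctc decoding

One rough edge worth noting: DeepSpeechLoss expects the downsampled frame lengths produced inside predict(), but forward() returns only the logits, so callers would have to recompute those lengths before calling the loss.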
@@ -0,0 +1,279 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import logging
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict

import parakeet
from parakeet.utils import checkpoint, mp_tools

__all__ = ["ExperimentBase"]


class ExperimentBase(object):
    """
    An experiment template to structure the training code and take care of
    saving, loading, logging and visualization. It's intended to be flexible
    and simple.

    So it only handles the output directory (create a directory for the
    output, create a checkpoint directory, dump the config in use and create
    visualizer and logger) in a standard way, without enforcing any
    input-output protocols on the model and dataloader. It leaves the main
    part for the user to implement their own (setup the model, criterion,
    optimizer, define a training step, define a validation function and
    customize all the text and visual logs).
    It does not save too much boilerplate code. The users still have to write
    the forward/backward/update manually, but they are free to add
    non-standard behaviors if needed.
    We have some conventions to follow.
    1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
       ``valid_loader``, ``config`` and ``args`` attributes.
    2. The config should have a ``training`` field, which has
       ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
       used as the trigger to invoke validation, checkpointing and stop of the
       experiment.
    3. There are four methods, namely ``train_batch``, ``valid``,
       ``setup_model`` and ``setup_dataloader``, that should be implemented.
    Feel free to add/overwrite other methods and standalone functions if you
    need.

    Parameters
    ----------
    config: yacs.config.CfgNode
        The configuration used for the experiment.

    args: argparse.Namespace
        The parsed command line arguments.
    Examples
    --------
    >>> def main_sp(config, args):
    >>>     exp = Experiment(config, args)
    >>>     exp.setup()
    >>>     exp.run()
    >>>
    >>> config = get_cfg_defaults()
    >>> parser = default_argument_parser()
    >>> args = parser.parse_args()
    >>> if args.config:
    >>>     config.merge_from_file(args.config)
    >>> if args.opts:
    >>>     config.merge_from_list(args.opts)
    >>> config.freeze()
    >>>
    >>> if args.nprocs > 1 and args.device == "gpu":
    >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    >>> else:
    >>>     main_sp(config, args)
    """

    def __init__(self, config, args):
        self.config = config
        self.args = args

    def setup(self):
        """Setup the experiment.
        """
        paddle.set_device(self.args.device)
        if self.parallel:
            self.init_parallel()

        self.setup_output_dir()
        self.dump_config()
        self.setup_visualizer()
        self.setup_logger()
        self.setup_checkpointer()

        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    @property
    def parallel(self):
        """A flag indicating whether the experiment should run with
        multiprocessing.
        """
        return self.args.device == "gpu" and self.args.nprocs > 1

    def init_parallel(self):
        """Init environment for multiprocess training.
        """
        dist.init_parallel_env()

    def save(self):
        """Save checkpoint (model parameters and optimizer states).
        """
        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
                                   self.model, self.optimizer)

    def load_or_resume(self):
        """Resume from the latest checkpoint in the output directory or load
        a specified checkpoint.

        If ``args.checkpoint_path`` is not None, load the checkpoint, else
        resume training.
        """
        iteration = checkpoint.load_parameters(
            self.model,
            self.optimizer,
            checkpoint_dir=self.checkpoint_dir,
            checkpoint_path=self.args.checkpoint_path)
        self.iteration = iteration

    def read_batch(self):
        """Read a batch from the train_loader.
        Returns
        -------
        List[Tensor]
            A batch.
        """
        try:
            batch = next(self.iterator)
        except StopIteration:
            self.new_epoch()
            batch = next(self.iterator)
        return batch

    def new_epoch(self):
        """Reset the train loader and increment ``epoch``.
        """
        self.epoch += 1
        if self.parallel:
            self.train_loader.batch_sampler.set_epoch(self.epoch)
        self.iterator = iter(self.train_loader)

    def train(self):
        """The training process.

        It includes forward/backward/update and periodical validation and
        saving.
        """
        self.new_epoch()
        while self.iteration < self.config.training.max_iteration:
            self.iteration += 1
            self.train_batch()

            if self.iteration % self.config.training.valid_interval == 0:
                self.valid()

            if self.iteration % self.config.training.save_interval == 0:
                self.save()

    def run(self):
        """The routine of the experiment after setup. This method is intended
        to be used by the user.
        """
        self.load_or_resume()
        try:
            self.train()
        except KeyboardInterrupt:
            self.save()
            exit(-1)

    @mp_tools.rank_zero_only
    def setup_output_dir(self):
        """Create a directory used for output.
        """
        # output dir
        output_dir = Path(self.args.output).expanduser()
        output_dir.mkdir(parents=True, exist_ok=True)

        self.output_dir = output_dir

    @mp_tools.rank_zero_only
    def setup_checkpointer(self):
        """Create a directory used to save checkpoints into.

        It is "checkpoints" inside the output directory.
        """
        # checkpoint dir
        checkpoint_dir = self.output_dir / "checkpoints"
        checkpoint_dir.mkdir(exist_ok=True)

        self.checkpoint_dir = checkpoint_dir

    @mp_tools.rank_zero_only
    def setup_visualizer(self):
        """Initialize a visualizer to log the experiment.

        The visual log is saved in the output directory.

        Notes
        -----
        Only the main process has a visualizer. Using multiple visualizers in
        multiprocess to write to the same log file may cause unexpected
        behaviors.
        """
        # visualizer
        visualizer = SummaryWriter(logdir=str(self.output_dir))

        self.visualizer = visualizer

    def setup_logger(self):
        """Initialize a text logger to log the experiment.

        Each process has its own text logger. The logging message is written
        to the standard output and a text file named ``worker_n.log`` in the
        output directory, where ``n`` is the rank of the process.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel("INFO")
        logger.addHandler(logging.StreamHandler())
        log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
        logger.addHandler(logging.FileHandler(str(log_file)))

        self.logger = logger

    @mp_tools.rank_zero_only
    def dump_config(self):
        """Save the configuration used for this experiment.

        It is saved into ``config.yaml`` in the output directory at the
        beginning of the experiment.
        """
        with open(self.output_dir / "config.yaml", 'wt') as f:
            print(self.config, file=f)

    def train_batch(self):
        """The training loop. A subclass should implement this method.
        """
        raise NotImplementedError("train_batch should be implemented.")

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        """The validation. A subclass should implement this method.
        """
        raise NotImplementedError("valid should be implemented.")

    def setup_model(self):
        """Setup model, criterion and optimizer, etc. A subclass should
        implement this method.
        """
        raise NotImplementedError("setup_model should be implemented.")

    def setup_dataloader(self):
        """Setup training dataloader and validation dataloader. A subclass
        should implement this method.
        """
        raise NotImplementedError("setup_dataloader should be implemented.")
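
A minimal sketch of a config satisfying the ``training`` contract described in the docstring above; the field names come from the docstring, while the yacs usage and the interval values are assumptions:

    from yacs.config import CfgNode

    _C = CfgNode()
    _C.training = CfgNode()
    _C.training.valid_interval = 1000    # run valid() every 1000 iterations
    _C.training.save_interval = 5000     # checkpoint every 5000 iterations
    _C.training.max_iteration = 500000   # stop condition for the loop in train()

    def get_cfg_defaults():
        # return a clone so each experiment can merge_from_file() and freeze()
        # its own copy independently
        return _C.clone()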
@@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from training.trainer import *
@@ -0,0 +1,279 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import logging
from pathlib import Path
import numpy as np
from collections import defaultdict

import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter

from utils import checkpoint
from utils import mp_tools

__all__ = ["Trainer"]


class Trainer():
    """
    An experiment template to structure the training code and take care of
    saving, loading, logging and visualization. It's intended to be flexible
    and simple.

    So it only handles the output directory (create a directory for the
    output, create a checkpoint directory, dump the config in use and create
    visualizer and logger) in a standard way, without enforcing any
    input-output protocols on the model and dataloader. It leaves the main
    part for the user to implement their own (setup the model, criterion,
    optimizer, define a training step, define a validation function and
    customize all the text and visual logs).
    It does not save too much boilerplate code. The users still have to write
    the forward/backward/update manually, but they are free to add
    non-standard behaviors if needed.
    We have some conventions to follow.
    1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
       ``valid_loader``, ``config`` and ``args`` attributes.
    2. The config should have a ``training`` field, which has
       ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
       used as the trigger to invoke validation, checkpointing and stop of the
       experiment.
    3. There are four methods, namely ``train_batch``, ``valid``,
       ``setup_model`` and ``setup_dataloader``, that should be implemented.
    Feel free to add/overwrite other methods and standalone functions if you
    need.

    Parameters
    ----------
    config: yacs.config.CfgNode
        The configuration used for the experiment.

    args: argparse.Namespace
        The parsed command line arguments.
    Examples
    --------
    >>> def main_sp(config, args):
    >>>     exp = Trainer(config, args)
    >>>     exp.setup()
    >>>     exp.run()
    >>>
    >>> config = get_cfg_defaults()
    >>> parser = default_argument_parser()
    >>> args = parser.parse_args()
    >>> if args.config:
    >>>     config.merge_from_file(args.config)
    >>> if args.opts:
    >>>     config.merge_from_list(args.opts)
    >>> config.freeze()
    >>>
    >>> if args.nprocs > 1 and args.device == "gpu":
    >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    >>> else:
    >>>     main_sp(config, args)
    """

    def __init__(self, config, args):
        self.config = config
        self.args = args

    def setup(self):
        """Setup the experiment.
        """
        paddle.set_device(self.args.device)
        if self.parallel:
            self.init_parallel()

        self.setup_output_dir()
        self.dump_config()
        self.setup_visualizer()
        self.setup_logger()
        self.setup_checkpointer()

        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    @property
    def parallel(self):
        """A flag indicating whether the experiment should run with
        multiprocessing.
        """
        return self.args.device == "gpu" and self.args.nprocs > 1

    def init_parallel(self):
        """Init environment for multiprocess training.
        """
        dist.init_parallel_env()

    def save(self):
        """Save checkpoint (model parameters and optimizer states).
        """
        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
                                   self.model, self.optimizer)

    def resume_or_load(self):
        """Resume from the latest checkpoint in the output directory or load
        a specified checkpoint.

        If ``args.checkpoint_path`` is not None, load the checkpoint, else
        resume training.
        """
        iteration = checkpoint.load_parameters(
            self.model,
            self.optimizer,
            checkpoint_dir=self.checkpoint_dir,
            checkpoint_path=self.args.checkpoint_path)
        self.iteration = iteration

    def read_batch(self):
        """Read a batch from the train_loader.
        Returns
        -------
        List[Tensor]
            A batch.
        """
        try:
            batch = next(self.iterator)
        except StopIteration:
            self.new_epoch()
            batch = next(self.iterator)
        return batch

    def new_epoch(self):
        """Reset the train loader and increment ``epoch``.
        """
        self.epoch += 1
        if self.parallel:
            self.train_loader.batch_sampler.set_epoch(self.epoch)
        self.iterator = iter(self.train_loader)

    def train(self):
        """The training process.

        It includes forward/backward/update and periodical validation and
        saving.
        """
        self.new_epoch()
        while self.iteration < self.config.training.max_iteration:
            self.iteration += 1
            self.train_batch()

            if self.iteration % self.config.training.valid_interval == 0:
                self.valid()

            if self.iteration % self.config.training.save_interval == 0:
                self.save()

    def run(self):
        """The routine of the experiment after setup. This method is intended
        to be used by the user.
        """
        self.resume_or_load()
        try:
            self.train()
        except KeyboardInterrupt:
            self.save()
            exit(-1)

    @mp_tools.rank_zero_only
    def setup_output_dir(self):
        """Create a directory used for output.
        """
        # output dir
        output_dir = Path(self.args.output).expanduser()
        output_dir.mkdir(parents=True, exist_ok=True)

        self.output_dir = output_dir

    @mp_tools.rank_zero_only
    def setup_checkpointer(self):
        """Create a directory used to save checkpoints into.

        It is "checkpoints" inside the output directory.
        """
        # checkpoint dir
        checkpoint_dir = self.output_dir / "checkpoints"
        checkpoint_dir.mkdir(exist_ok=True)

        self.checkpoint_dir = checkpoint_dir

    @mp_tools.rank_zero_only
    def setup_visualizer(self):
        """Initialize a visualizer to log the experiment.

        The visual log is saved in the output directory.

        Notes
        -----
        Only the main process has a visualizer. Using multiple visualizers in
        multiprocess to write to the same log file may cause unexpected
        behaviors.
        """
        # visualizer
        visualizer = SummaryWriter(logdir=str(self.output_dir))

        self.visualizer = visualizer

    def setup_logger(self):
        """Initialize a text logger to log the experiment.

        Each process has its own text logger. The logging message is written
        to the standard output and a text file named ``worker_n.log`` in the
        output directory, where ``n`` is the rank of the process.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel("INFO")
        logger.addHandler(logging.StreamHandler())
        log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
        logger.addHandler(logging.FileHandler(str(log_file)))

        self.logger = logger

    @mp_tools.rank_zero_only
    def dump_config(self):
        """Save the configuration used for this experiment.

        It is saved into ``config.yaml`` in the output directory at the
        beginning of the experiment.
        """
        with open(self.output_dir / "config.yaml", 'wt') as f:
            print(self.config, file=f)

    def train_batch(self):
        """The training loop. A subclass should implement this method.
        """
        raise NotImplementedError("train_batch should be implemented.")

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        """The validation. A subclass should implement this method.
        """
        raise NotImplementedError("valid should be implemented.")

    def setup_model(self):
        """Setup model, criterion and optimizer, etc. A subclass should
        implement this method.
        """
        raise NotImplementedError("setup_model should be implemented.")

    def setup_dataloader(self):
        """Setup training dataloader and validation dataloader. A subclass
        should implement this method.
        """
        raise NotImplementedError("setup_dataloader should be implemented.")
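
A minimal sketch of a subclass filling in the four required hooks. Only the hook names and base-class attributes come from the code above; the model, optimizer and datasets (my_train_dataset, my_valid_dataset) are hypothetical placeholders:

    from paddle.io import DataLoader

    class ToyExperiment(Trainer):  # hypothetical subclass
        def setup_model(self):
            model = paddle.nn.Linear(784, 10)  # stand-in model
            self.model = paddle.DataParallel(model) if self.parallel else model
            self.optimizer = paddle.optimizer.Adam(
                learning_rate=1e-3, parameters=model.parameters())

        def setup_dataloader(self):
            # the base class expects train_loader and valid_loader attributes
            self.train_loader = DataLoader(my_train_dataset, batch_size=32, shuffle=True)
            self.valid_loader = DataLoader(my_valid_dataset, batch_size=32)

        def train_batch(self):
            inputs, labels = self.read_batch()
            loss = paddle.nn.functional.cross_entropy(self.model(inputs), labels)
            loss.backward()
            self.optimizer.step()
            self.optimizer.clear_grad()
            self.logger.info("iteration %d: loss %.4f", self.iteration, float(loss))

        @mp_tools.rank_zero_only
        @paddle.no_grad()
        def valid(self):
            pass  # compute metrics over self.valid_loader, log via self.visualizer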
@@ -0,0 +1,135 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.nn import Layer
from paddle.optimizer import Optimizer

from utils import mp_tools

__all__ = ["load_parameters", "save_parameters"]


def _load_latest_checkpoint(checkpoint_dir: str) -> int:
    """Get the iteration number corresponding to the latest saved checkpoint.
    Args:
        checkpoint_dir (str): the directory where checkpoints are saved.
    Returns:
        int: the latest iteration number.
    """
    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(checkpoint_record):
        return 0

    # Fetch the latest checkpoint index.
    with open(checkpoint_record, "rt") as handle:
        latest_checkpoint = handle.readlines()[-1].strip()
        step = latest_checkpoint.split(":")[-1]
        iteration = int(step.split("-")[-1])

    return iteration


def _save_checkpoint(checkpoint_dir: str, iteration: int):
    """Save the iteration number of the latest model to be checkpointed.
    Args:
        checkpoint_dir (str): the directory where checkpoints are saved.
        iteration (int): the latest iteration number.
    Returns:
        None
    """
    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
    # Update the latest checkpoint index.
    with open(checkpoint_record, "a+") as handle:
        handle.write("model_checkpoint_path:step-{}\n".format(iteration))


def load_parameters(model,
                    optimizer=None,
                    checkpoint_dir=None,
                    checkpoint_path=None):
    """Load a specific model checkpoint from disk.
    Args:
        model (Layer): model to load parameters into.
        optimizer (Optimizer, optional): optimizer to load states if needed.
            Defaults to None.
        checkpoint_dir (str, optional): the directory where checkpoints are saved.
        checkpoint_path (str, optional): if specified, load the checkpoint
            stored in checkpoint_path and the argument 'checkpoint_dir' will
            be ignored. Defaults to None.
    Returns:
        iteration (int): number of iterations that the loaded checkpoint has
            been trained.
    """
    if checkpoint_path is not None:
        iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
    elif checkpoint_dir is not None:
        iteration = _load_latest_checkpoint(checkpoint_dir)
        if iteration == 0:
            return iteration
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "step-{}".format(iteration))
    else:
        raise ValueError(
            "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
        )

    rank = dist.get_rank()

    params_path = checkpoint_path + ".pdparams"
    model_dict = paddle.load(params_path)
    model.set_state_dict(model_dict)
    print(
        "[checkpoint] Rank {}: loaded model from {}".format(rank, params_path))

    optimizer_path = checkpoint_path + ".pdopt"
    if optimizer and os.path.isfile(optimizer_path):
        optimizer_dict = paddle.load(optimizer_path)
        optimizer.set_state_dict(optimizer_dict)
        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
            rank, optimizer_path))

    return iteration


@mp_tools.rank_zero_only
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
    """Checkpoint the latest trained model parameters.
    Args:
        checkpoint_dir (str): the directory where checkpoints are saved.
        iteration (int): the latest iteration number.
        model (Layer): model to be checkpointed.
        optimizer (Optimizer, optional): optimizer to be checkpointed.
            Defaults to None.
    Returns:
        None
    """
    checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))

    model_dict = model.state_dict()
    params_path = checkpoint_path + ".pdparams"
    paddle.save(model_dict, params_path)
    print("[checkpoint] Saved model to {}".format(params_path))

    if optimizer:
        opt_dict = optimizer.state_dict()
        optimizer_path = checkpoint_path + ".pdopt"
        paddle.save(opt_dict, optimizer_path)
        print("[checkpoint] Saved optimizer state to {}".format(
            optimizer_path))

    _save_checkpoint(checkpoint_dir, iteration)
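
A round-trip sketch of the two public functions, assuming a placeholder directory 'exp/checkpoints' and a model/optimizer already in scope. After the save, the directory holds step-1000.pdparams, step-1000.pdopt, and a plain-text 'checkpoint' record whose last line reads model_checkpoint_path:step-1000, which _load_latest_checkpoint parses to recover the iteration:

    from utils import checkpoint

    # rank 0 writes the files; other ranks are skipped by @rank_zero_only
    checkpoint.save_parameters("exp/checkpoints", 1000, model, optimizer)

    # every rank reloads the newest step found in the record file
    start_iteration = checkpoint.load_parameters(
        model, optimizer, checkpoint_dir="exp/checkpoints")
    assert start_iteration == 1000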
@@ -0,0 +1,32 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import distributed as dist
from functools import wraps

__all__ = ["rank_zero_only"]


def rank_zero_only(func):
    """Decorator that makes ``func`` a no-op on every process except rank 0."""
    rank = dist.get_rank()

    @wraps(func)
    def wrapper(*args, **kwargs):
        if rank != 0:
            return
        result = func(*args, **kwargs)
        return result

    return wrapper
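
A usage sketch. Note that the rank is captured once, when the decorator is applied, so decorated functions should be defined after the distributed environment (and hence the rank) is in place; an ordering assumption worth keeping in mind for spawned workers:

    from utils.mp_tools import rank_zero_only

    @rank_zero_only
    def write_summary(message):
        # executes only on rank 0; other ranks return None without calling through
        print(message)

    write_summary("this prints once, not once per process")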