pull/3984/head
co63oc 7 months ago committed by GitHub
parent 59d641bc14
commit bb77a7f7db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -33,7 +33,7 @@ If applicable, add screenshots to help explain your problem.
- Python Version [e.g. 3.7] - Python Version [e.g. 3.7]
- PaddlePaddle Version [e.g. 2.0.0] - PaddlePaddle Version [e.g. 2.0.0]
- Model Version [e.g. 2.0.0] - Model Version [e.g. 2.0.0]
- GPU/DRIVER Informationo [e.g. Tesla V100-SXM2-32GB/440.64.00] - GPU/DRIVER Information [e.g. Tesla V100-SXM2-32GB/440.64.00]
- CUDA/CUDNN Version [e.g. cuda-10.2] - CUDA/CUDNN Version [e.g. cuda-10.2]
- MKL Version - MKL Version
- TensorRT Version - TensorRT Version

@ -32,7 +32,7 @@ If applicable, add screenshots to help explain your problem.
- Python Version [e.g. 3.7] - Python Version [e.g. 3.7]
- PaddlePaddle Version [e.g. 2.0.0] - PaddlePaddle Version [e.g. 2.0.0]
- Model Version [e.g. 2.0.0] - Model Version [e.g. 2.0.0]
- GPU/DRIVER Informationo [e.g. Tesla V100-SXM2-32GB/440.64.00] - GPU/DRIVER Information [e.g. Tesla V100-SXM2-32GB/440.64.00]
- CUDA/CUDNN Version [e.g. cuda-10.2] - CUDA/CUDNN Version [e.g. cuda-10.2]
- MKL Version - MKL Version
- TensorRT Version - TensorRT Version

@ -61,7 +61,7 @@ def resample(y: np.ndarray,
if mode == 'kaiser_best': if mode == 'kaiser_best':
warnings.warn( warnings.warn(
f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \ f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \
we recommend the mode kaiser_fast in large scale audio trainning') we recommend the mode kaiser_fast in large scale audio training')
if not isinstance(y, np.ndarray): if not isinstance(y, np.ndarray):
raise ParameterError( raise ParameterError(

@ -233,7 +233,7 @@ def spectrogram(waveform: Tensor,
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. Defaults to True. to FFT. Defaults to True.
sr (int, optional): Sample rate of input waveform. Defaults to 16000. sr (int, optional): Sample rate of input waveform. Defaults to 16000.
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a signal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True. is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False. subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey". window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
@ -443,7 +443,7 @@ def fbank(waveform: Tensor,
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. Defaults to True. to FFT. Defaults to True.
sr (int, optional): Sample rate of input waveform. Defaults to 16000. sr (int, optional): Sample rate of input waveform. Defaults to 16000.
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a signal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True. is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False. subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False. use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
@ -566,7 +566,7 @@ def mfcc(waveform: Tensor,
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. Defaults to True. to FFT. Defaults to True.
sr (int, optional): Sample rate of input waveform. Defaults to 16000. sr (int, optional): Sample rate of input waveform. Defaults to 16000.
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a signal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True. is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False. subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False. use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.

@ -47,7 +47,7 @@ class AudioClassificationDataset(paddle.io.Dataset):
files (:obj:`List[str]`): A list of absolute path of audio files. files (:obj:`List[str]`): A list of absolute path of audio files.
labels (:obj:`List[int]`): Labels of audio files. labels (:obj:`List[int]`): Labels of audio files.
feat_type (:obj:`str`, `optional`, defaults to `raw`): feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file. It identifies the feature type that user wants to extract of an audio file.
""" """
super(AudioClassificationDataset, self).__init__() super(AudioClassificationDataset, self).__init__()

@ -117,7 +117,7 @@ class ESC50(AudioClassificationDataset):
split (:obj:`int`, `optional`, defaults to 1): split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset. It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`): feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file. It identifies the feature type that user wants to extract of an audio file.
""" """
files, labels = self._get_data(mode, split) files, labels = self._get_data(mode, split)
super(ESC50, self).__init__( super(ESC50, self).__init__(

@ -67,7 +67,7 @@ class GTZAN(AudioClassificationDataset):
split (:obj:`int`, `optional`, defaults to 1): split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset. It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`): feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file. It identifies the feature type that user wants to extract of an audio file.
""" """
assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}' assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}'
files, labels = self._get_data(mode, seed, n_folds, split) files, labels = self._get_data(mode, seed, n_folds, split)

@ -76,7 +76,7 @@ class TESS(AudioClassificationDataset):
split (:obj:`int`, `optional`, defaults to 1): split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset. It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`): feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file. It identifies the feature type that user wants to extract of an audio file.
""" """
assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}' assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}'
files, labels = self._get_data(mode, seed, n_folds, split) files, labels = self._get_data(mode, seed, n_folds, split)

@ -68,7 +68,7 @@ class UrbanSound8K(AudioClassificationDataset):
split (:obj:`int`, `optional`, defaults to 1): split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset. It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`): feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file. It identifies the feature type that user wants to extract of an audio file.
""" """
def _get_meta_info(self): def _get_meta_info(self):

@ -262,8 +262,8 @@ class VoxCeleb(Dataset):
split_chunks: bool=True): split_chunks: bool=True):
print(f'Generating csv: {output_file}') print(f'Generating csv: {output_file}')
header = ["id", "duration", "wav", "start", "stop", "spk_id"] header = ["id", "duration", "wav", "start", "stop", "spk_id"]
# Note: this may occurs c++ execption, but the program will execute fine # Note: this may occurs c++ exception, but the program will execute fine
# so we can ignore the execption # so we can ignore the exception
with Pool(cpu_count()) as p: with Pool(cpu_count()) as p:
infos = list( infos = list(
tqdm( tqdm(

@ -34,7 +34,7 @@ __all__ = [
class Spectrogram(nn.Layer): class Spectrogram(nn.Layer):
"""Compute spectrogram of given signals, typically audio waveforms. """Compute spectrogram of given signals, typically audio waveforms.
The spectorgram is defined as the complex norm of the short-time Fourier transformation. The spectrogram is defined as the complex norm of the short-time Fourier transformation.
Args: Args:
n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512.

@ -247,7 +247,7 @@ def create_dct(n_mfcc: int,
Args: Args:
n_mfcc (int): Number of mel frequency cepstral coefficients. n_mfcc (int): Number of mel frequency cepstral coefficients.
n_mels (int): Number of mel filterbanks. n_mels (int): Number of mel filterbanks.
norm (Optional[str], optional): Normalizaiton type. Defaults to 'ortho'. norm (Optional[str], optional): Normalization type. Defaults to 'ortho'.
dtype (str, optional): The data type of the return matrix. Defaults to 'float32'. dtype (str, optional): The data type of the return matrix. Defaults to 'float32'.
Returns: Returns:

@ -22,8 +22,8 @@ def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
"""Compute EER and return score threshold. """Compute EER and return score threshold.
Args: Args:
labels (np.ndarray): the trial label, shape: [N], one-dimention, N refer to the samples num labels (np.ndarray): the trial label, shape: [N], one-dimension, N refer to the samples num
scores (np.ndarray): the trial scores, shape: [N], one-dimention, N refer to the samples num scores (np.ndarray): the trial scores, shape: [N], one-dimension, N refer to the samples num
Returns: Returns:
List[float]: eer and the specific threshold List[float]: eer and the specific threshold

@ -121,8 +121,8 @@ def apply_effects_tensor(
""" """
tensor_np = tensor.numpy() tensor_np = tensor.numpy()
ret = paddleaudio._paddleaudio.sox_effects_apply_effects_tensor(tensor_np, sample_rate, ret = paddleaudio._paddleaudio.sox_effects_apply_effects_tensor(
effects, channels_first) tensor_np, sample_rate, effects, channels_first)
if ret is not None: if ret is not None:
return (paddle.to_tensor(ret[0]), ret[1]) return (paddle.to_tensor(ret[0]), ret[1])
raise RuntimeError("Failed to apply sox effect") raise RuntimeError("Failed to apply sox effect")
@ -139,7 +139,7 @@ def apply_effects_file(
Note: Note:
This function works in the way very similar to ``sox`` command, however there are slight This function works in the way very similar to ``sox`` command, however there are slight
differences. For example, ``sox`` commnad adds certain effects automatically (such as differences. For example, ``sox`` command adds certain effects automatically (such as
``rate`` effect after ``speed``, ``pitch`` etc), but this function only applies the given ``rate`` effect after ``speed``, ``pitch`` etc), but this function only applies the given
effects. Therefore, to actually apply ``speed`` effect, you also need to give ``rate`` effects. Therefore, to actually apply ``speed`` effect, you also need to give ``rate``
effect with desired sampling rate, because internally, ``speed`` effects only alter sampling effect with desired sampling rate, because internally, ``speed`` effects only alter sampling
@ -228,14 +228,14 @@ def apply_effects_file(
>>> pass >>> pass
""" """
if hasattr(path, "read"): if hasattr(path, "read"):
ret = paddleaudio._paddleaudio.apply_effects_fileobj(path, effects, normalize, ret = paddleaudio._paddleaudio.apply_effects_fileobj(
channels_first, format) path, effects, normalize, channels_first, format)
if ret is None: if ret is None:
raise RuntimeError("Failed to load audio from {}".format(path)) raise RuntimeError("Failed to load audio from {}".format(path))
return (paddle.to_tensor(ret[0]), ret[1]) return (paddle.to_tensor(ret[0]), ret[1])
path = os.fspath(path) path = os.fspath(path)
ret = paddleaudio._paddleaudio.sox_effects_apply_effects_file(path, effects, normalize, ret = paddleaudio._paddleaudio.sox_effects_apply_effects_file(
channels_first, format) path, effects, normalize, channels_first, format)
if ret is not None: if ret is not None:
return (paddle.to_tensor(ret[0]), ret[1]) return (paddle.to_tensor(ret[0]), ret[1])
raise RuntimeError("Failed to load audio from {}".format(path)) raise RuntimeError("Failed to load audio from {}".format(path))

@ -26,7 +26,7 @@ template <class F>
bool StreamingFeatureTpl<F>::ComputeFeature( bool StreamingFeatureTpl<F>::ComputeFeature(
const std::vector<float>& wav, const std::vector<float>& wav,
std::vector<float>* feats) { std::vector<float>* feats) {
// append remaned waves // append remained waves
int wav_len = wav.size(); int wav_len = wav.size();
if (wav_len == 0) return false; if (wav_len == 0) return false;
int left_len = remained_wav_.size(); int left_len = remained_wav_.size();
@ -38,7 +38,7 @@ bool StreamingFeatureTpl<F>::ComputeFeature(
wav.data(), wav.data(),
wav_len * sizeof(float)); wav_len * sizeof(float));
// cache remaned waves // cache remained waves
knf::FrameExtractionOptions frame_opts = computer_.GetFrameOptions(); knf::FrameExtractionOptions frame_opts = computer_.GetFrameOptions();
int num_frames = knf::NumFrames(waves.size(), frame_opts); int num_frames = knf::NumFrames(waves.size(), frame_opts);
int frame_shift = frame_opts.WindowShift(); int frame_shift = frame_opts.WindowShift();

@ -44,5 +44,5 @@ py::array_t<float> KaldiFeatureWrapper::ComputeFbank(
return result.reshape(shape); return result.reshape(shape);
} }
} // namesapce kaldi } // namespace kaldi
} // namespace paddleaudio } // namespace paddleaudio

@ -12,9 +12,9 @@ using namespace paddleaudio::sox_utils;
namespace paddleaudio::sox_effects { namespace paddleaudio::sox_effects {
// Streaming decoding over file-like object is tricky because libsox operates on // Streaming decoding over file-like object is tricky because libsox operates on
// FILE pointer. The folloing is what `sox` and `play` commands do // FILE pointer. The following is what `sox` and `play` commands do
// - file input -> FILE pointer // - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer // - URL input -> call wget in subprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer // - stdin -> FILE pointer
// //
// We want to, instead, fetch byte strings chunk by chunk, consume them, and // We want to, instead, fetch byte strings chunk by chunk, consume them, and
@ -127,12 +127,12 @@ namespace {
enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown }; enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown };
SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized; SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized;
std::mutex SOX_RESOUCE_STATE_MUTEX; std::mutex SOX_RESOURCE_STATE_MUTEX;
} // namespace } // namespace
void initialize_sox_effects() { void initialize_sox_effects() {
const std::lock_guard<std::mutex> lock(SOX_RESOUCE_STATE_MUTEX); const std::lock_guard<std::mutex> lock(SOX_RESOURCE_STATE_MUTEX);
switch (SOX_RESOURCE_STATE) { switch (SOX_RESOURCE_STATE) {
case NotInitialized: case NotInitialized:
@ -150,7 +150,7 @@ void initialize_sox_effects() {
}; };
void shutdown_sox_effects() { void shutdown_sox_effects() {
const std::lock_guard<std::mutex> lock(SOX_RESOUCE_STATE_MUTEX); const std::lock_guard<std::mutex> lock(SOX_RESOURCE_STATE_MUTEX);
switch (SOX_RESOURCE_STATE) { switch (SOX_RESOURCE_STATE) {
case NotInitialized: case NotInitialized:

@ -14,7 +14,7 @@ namespace {
/// helper classes for passing the location of input tensor and output buffer /// helper classes for passing the location of input tensor and output buffer
/// ///
/// drain/flow callback functions require plaing C style function signature and /// drain/flow callback functions require plain C style function signature and
/// the way to pass extra data is to attach data to sox_effect_t::priv pointer. /// the way to pass extra data is to attach data to sox_effect_t::priv pointer.
/// The following structs will be assigned to sox_effect_t::priv pointer which /// The following structs will be assigned to sox_effect_t::priv pointer which
/// gives sox_effect_t an access to input Tensor and output buffer object. /// gives sox_effect_t an access to input Tensor and output buffer object.
@ -50,7 +50,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
*osamp -= *osamp % num_channels; *osamp -= *osamp % num_channels;
// Slice the input Tensor // Slice the input Tensor
// refacor this module, chunk // refactor this module, chunk
auto i_frame = index / num_channels; auto i_frame = index / num_channels;
auto num_frames = *osamp / num_channels; auto num_frames = *osamp / num_channels;

@ -162,7 +162,7 @@ py::dtype get_dtype(
} }
default: default:
// default to float32 for the other formats, including // default to float32 for the other formats, including
// 32-bit flaoting-point WAV, // 32-bit floating-point WAV,
// MP3, // MP3,
// FLAC, // FLAC,
// VORBIS etc... // VORBIS etc...
@ -177,7 +177,7 @@ py::array convert_to_tensor(
const py::dtype dtype, const py::dtype dtype,
const bool normalize, const bool normalize,
const bool channels_first) { const bool channels_first) {
// todo refector later(SGoat) // todo refactor later(SGoat)
py::array t; py::array t;
uint64_t dummy = 0; uint64_t dummy = 0;
SOX_SAMPLE_LOCALS; SOX_SAMPLE_LOCALS;

@ -76,7 +76,7 @@ py::dtype get_dtype(
/// Tensor. /// Tensor.
/// @param dtype Target dtype. Determines the output dtype and value range in /// @param dtype Target dtype. Determines the output dtype and value range in
/// conjunction with normalization. /// conjunction with normalization.
/// @param noramlize Perform normalization. Only effective when dtype is not /// @param normalize Perform normalization. Only effective when dtype is not
/// kFloat32. When effective, the output tensor is kFloat32 type and value range /// kFloat32. When effective, the output tensor is kFloat32 type and value range
/// is [-1.0, 1.0] /// is [-1.0, 1.0]
/// @param channels_first When True, output Tensor has shape of [num_channels, /// @param channels_first When True, output Tensor has shape of [num_channels,

@ -8,9 +8,9 @@ set(patch_dir ${CMAKE_CURRENT_SOURCE_DIR}/../patches)
set(COMMON_ARGS --quiet --disable-shared --enable-static --prefix=${INSTALL_DIR} --with-pic --disable-dependency-tracking --disable-debug --disable-examples --disable-doc) set(COMMON_ARGS --quiet --disable-shared --enable-static --prefix=${INSTALL_DIR} --with-pic --disable-dependency-tracking --disable-debug --disable-examples --disable-doc)
# To pass custom environment variables to ExternalProject_Add command, # To pass custom environment variables to ExternalProject_Add command,
# we need to do `${CMAKE_COMMAND} -E env ${envs} <COMMANAD>`. # we need to do `${CMAKE_COMMAND} -E env ${envs} <COMMAND>`.
# https://stackoverflow.com/a/62437353 # https://stackoverflow.com/a/62437353
# We constrcut the custom environment variables here # We construct the custom environment variables here
set(envs set(envs
"PKG_CONFIG_PATH=${INSTALL_DIR}/lib/pkgconfig" "PKG_CONFIG_PATH=${INSTALL_DIR}/lib/pkgconfig"
"LDFLAGS=-L${INSTALL_DIR}/lib $ENV{LDFLAGS}" "LDFLAGS=-L${INSTALL_DIR}/lib $ENV{LDFLAGS}"

@ -41,14 +41,14 @@ def download_and_decompress(archives: List[Dict[str, str]],
path: str, path: str,
decompress: bool=True): decompress: bool=True):
""" """
Download archieves and decompress to specific path. Download archives and decompress to specific path.
""" """
if not os.path.isdir(path): if not os.path.isdir(path):
os.makedirs(path) os.makedirs(path)
for archive in archives: for archive in archives:
assert 'url' in archive and 'md5' in archive, \ assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}' 'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'
download.get_path_from_url( download.get_path_from_url(
archive['url'], path, archive['md5'], decompress=decompress) archive['url'], path, archive['md5'], decompress=decompress)

@ -58,7 +58,7 @@ log_config = {
class Logger(object): class Logger(object):
''' '''
Deafult logger in PaddleAudio Default logger in PaddleAudio
Args: Args:
name(str) : Logger name, default is 'PaddleAudio' name(str) : Logger name, default is 'PaddleAudio'
''' '''

@ -55,7 +55,7 @@ def set_use_threads(use_threads: bool):
Args: Args:
use_threads (bool): When ``True``, enables ``libsox``'s parallel effects channels processing. use_threads (bool): When ``True``, enables ``libsox``'s parallel effects channels processing.
To use mutlithread, the underlying ``libsox`` has to be compiled with OpenMP support. To use multithread, the underlying ``libsox`` has to be compiled with OpenMP support.
See Also: See Also:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Unility functions for Transformer.""" """Utility functions for Transformer."""
from typing import List from typing import List
from typing import Tuple from typing import Tuple
@ -80,7 +80,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
# assuming trailing dimensions and type of all the Tensors # assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0] # in sequences are same and fetching those from sequences[0]
max_size = paddle.shape(sequences[0]) max_size = paddle.shape(sequences[0])
# (TODO Hui Zhang): slice not supprot `end==start` # (TODO Hui Zhang): slice not support `end==start`
# trailing_dims = max_size[1:] # trailing_dims = max_size[1:]
trailing_dims = tuple( trailing_dims = tuple(
max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
@ -94,7 +94,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
length = tensor.shape[0] length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor # use index notation to prevent duplicate references to the tensor
if batch_first: if batch_first:
# TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not support `end==start`
# TODO (Hui Zhang): set_value op not support int16 # TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...] # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor # out_tensor[i, :length, ...] = tensor
@ -103,7 +103,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
else: else:
out_tensor[i, length] = tensor out_tensor[i, length] = tensor
else: else:
# TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not support `end==start`
# out_tensor[:length, i, ...] = tensor # out_tensor[:length, i, ...] = tensor
if length != 0: if length != 0:
out_tensor[:length, i] = tensor out_tensor[:length, i] = tensor

Loading…
Cancel
Save