|
|
|
@ -82,7 +82,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
|
|
|
|
|
max_size = sequences[0].size()
|
|
|
|
|
# (TODO Hui Zhang): slice not supprot `end==start`
|
|
|
|
|
# trailing_dims = max_size[1:]
|
|
|
|
|
trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
|
|
|
|
|
trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
|
|
|
|
|
max_len = max([s.shape[0] for s in sequences])
|
|
|
|
|
if batch_first:
|
|
|
|
|
out_dims = (len(sequences), max_len) + trailing_dims
|
|
|
|
|