|
|
|
@ -52,9 +52,8 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
|
|
|
|
|
"""
|
|
|
|
|
peek_example = minibatch[0]
|
|
|
|
|
assert len(peek_example.shape) == 1, "text example is an 1D tensor"
|
|
|
|
|
|
|
|
|
|
lengths = [example.shape[0] for example in
|
|
|
|
|
minibatch] # assume (channel, n_samples) or (n_samples, )
|
|
|
|
|
# assume (channel, n_samples) or (n_samples, )
|
|
|
|
|
lengths = [example.shape[0] for example in minibatch]
|
|
|
|
|
max_len = np.max(lengths)
|
|
|
|
|
|
|
|
|
|
batch = []
|
|
|
|
|