pull/1050/head
Junkun 3 years ago
parent e867f3bb41
commit d2fab3238b

@ -102,10 +102,10 @@ def read_manifest(
manifest = [] manifest = []
with jsonlines.open(manifest_path, 'r') as reader: with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader: for json_data in reader:
feat_len = json_data["feat_shape"][ feat_len = json_data["input"][0]["shape"][
0] if 'feat_shape' in json_data else 1.0 0] if 'shape' in json_data["input"][0] else 1.0
token_len = json_data["token_shape"][ token_len = json_data["output"][0]["shape"][
0] if 'token_shape' in json_data else 1.0 0] if 'shape' in json_data["output"][0] else 1.0
conditions = [ conditions = [
feat_len >= min_input_len, feat_len >= min_input_len,
feat_len <= max_input_len, feat_len <= max_input_len,

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False

@ -94,6 +94,9 @@ class Checkpoint():
""" """
configs = {} configs = {}
if len(checkpoint_path) == 0 or checkpoint_path == "None":
checkpoint_path = None
if checkpoint_path is not None: if checkpoint_path is not None:
pass pass
elif checkpoint_dir is not None and record_file is not None: elif checkpoint_dir is not None and record_file is not None:

Loading…
Cancel
Save