pull/1050/head
Junkun 3 years ago
parent e867f3bb41
commit d2fab3238b

@ -102,10 +102,10 @@ def read_manifest(
manifest = [] manifest = []
with jsonlines.open(manifest_path, 'r') as reader: with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader: for json_data in reader:
feat_len = json_data["feat_shape"][ feat_len = json_data["input"][0]["shape"][
0] if 'feat_shape' in json_data else 1.0 0] if 'shape' in json_data["input"][0] else 1.0
token_len = json_data["token_shape"][ token_len = json_data["output"][0]["shape"][
0] if 'token_shape' in json_data else 1.0 0] if 'shape' in json_data["output"][0] else 1.0
conditions = [ conditions = [
feat_len >= min_input_len, feat_len >= min_input_len,
feat_len <= max_input_len, feat_len <= max_input_len,

@ -94,6 +94,9 @@ class Checkpoint():
""" """
configs = {} configs = {}
if len(checkpoint_path) == 0 or checkpoint_path == "None":
checkpoint_path = None
if checkpoint_path is not None: if checkpoint_path is not None:
pass pass
elif checkpoint_dir is not None and record_file is not None: elif checkpoint_dir is not None and record_file is not None:

Loading…
Cancel
Save