format data support multi output

pull/1012/head
Hui Zhang 3 years ago
parent f89f99fe4a
commit 02c7ef3198

@ -2,6 +2,7 @@
set -e
source path.sh
gpus=0,1,2,3
stage=0
stop_stage=100
conf_path=conf/transformer_joint_noam.yaml
@ -21,7 +22,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then

@ -87,15 +87,24 @@ def main():
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
output_json['output'].append({
'name': 'traget1',
'name': 'target1',
'shape': (len(tokenids), vocab_size),
'text': line,
'token': ' '.join(tokens),
'tokenid': ' '.join(map(str, tokenids)),
})
else:
# isinstance(line, list), multi target
raise NotImplementedError("not support multi output now!")
# isinstance(line, list), multi target in one vocab
for i, item in enumerate(line, 1):
tokens = text_feature.tokenize(item)
tokenids = text_feature.featurize(item)
output_json['output'].append({
'name': f'target{i}',
'shape': (len(tokenids), vocab_size),
'text': item,
'token': ' '.join(tokens),
'tokenid': ' '.join(map(str, tokenids)),
})
# input
line = line_json['feat']

Loading…
Cancel
Save