#!/usr/bin/env python3
# Prune an ONNX model down to the subgraph needed to produce the given output tensors.
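# Example invocation (script name, file paths and tensor names are
# illustrative):
#   python prune_onnx_model.py --model model.onnx \
#       --output_names relu_2.tmp_0 --save_file pruned_model.onnx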
import argparse
import sys

import onnx


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model',
        required=True,
        help='Path of the input ONNX model file.')
    parser.add_argument(
        '--output_names',
        required=True,
        nargs='+',
        help='Names of the output tensors that the pruned model should produce.')
    parser.add_argument(
        '--save_file', required=True, help='Path to save the new onnx model.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    if len(set(args.output_names)) < len(args.output_names):
        print(
            "[ERROR] There's dumplicate name in --output_names, which is not allowed."
        )
        sys.exit(-1)

    model = onnx.load(args.model)

    # collect every tensor produced by a node; the graph outputs are
    # among them, since each graph output is produced by some node
    output_tensor_names = set()
    for node in model.graph.node:
        for out in node.output:
            output_tensor_names.add(out)

    for output_name in args.output_names:
        if output_name not in output_tensor_names:
            print(
                "[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
                format(output_name))
            sys.exit(-1)

    output_node_indices = set()  # indices of nodes producing the requested outputs
    output_to_node = dict()  # maps each tensor name to the index of its producer node
    for i, node in enumerate(model.graph.node):
        for out in node.output:
            output_to_node[out] = i
            if out in args.output_names:
                output_node_indices.add(i)

    # walk backwards from the requested outputs and keep every ancestor node
    reserved_node_indices = set(output_node_indices)  # nodes to keep
    reserved_inputs = set()  # graph inputs / initializers to keep
    # breadth-first frontier, initialised with the nodes that produce
    # the requested outputs
    new_output_node_indices = set(output_node_indices)

    while len(new_output_node_indices) > 0:
        output_node_indices = new_output_node_indices
        new_output_node_indices = set()

        for out_node_idx in output_node_indices:
            # backtrack to the parent nodes of this node
            for ipt in model.graph.node[out_node_idx].input:
                if ipt in output_to_node:
                    parent_idx = output_to_node[ipt]
                    # expand only unseen parents so every node is visited once
                    if parent_idx not in reserved_node_indices:
                        reserved_node_indices.add(parent_idx)
                        new_output_node_indices.add(parent_idx)
                else:
                    # not produced by any node: a graph input or an initializer
                    reserved_inputs.add(ipt)

    num_inputs = len(model.graph.input)
    num_outputs = len(model.graph.output)
    num_nodes = len(model.graph.node)
    print(
        f"old graph has {num_inputs} inputs, {num_outputs} outputs, {num_nodes} nodes"
    )
    print(f"{len(reserved_node_indices)} nodes to keep.")

    # delete nodes that are not kept; iterate in reverse so the indices
    # of the remaining nodes stay valid while deleting
    for idx in range(num_nodes - 1, -1, -1):
        if idx not in reserved_node_indices:
            del model.graph.node[idx]

    # delete graph inputs that no kept node consumes
    for idx in range(num_inputs - 1, -1, -1):
        if model.graph.input[idx].name not in reserved_inputs:
            del model.graph.input[idx]
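
    # also delete initializers that no kept node consumes; initializers are
    # never produced by a node, so the ones still needed are exactly the
    # names collected in reserved_inputs (without this step the weights of
    # pruned nodes would remain in the saved file)
    num_initializers = len(model.graph.initializer)
    for idx in range(num_initializers - 1, -1, -1):
        if model.graph.initializer[idx].name not in reserved_inputs:
            del model.graph.initializer[idx]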

    # drop all of the old graph outputs
    del model.graph.output[:]

    # register the user-requested tensors as the new graph outputs
    for out in args.output_names:
        model.graph.output.append(onnx.ValueInfoProto(name=out))
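    # note: these ValueInfoProtos carry only a tensor name, with no dtype or
    # shape; the optional shape-inference step below tries to fill them in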

    # infer shape
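    # SymbolicShapeInference is expected to come from a local onnx_infer_shape
    # helper (the arguments match onnxruntime's symbolic_shape_infer API);
    # shape inference is best-effort, so any failure is reported and skipped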
    try:
        from onnx_infer_shape import SymbolicShapeInference
        model = SymbolicShapeInference.infer_shapes(
            model,
            int_max=2**31 - 1,
            auto_merge=True,
            guess_output_rank=False,
            verbose=1)
    except Exception as e:
        print(f"skip infer shape step: {e}")

    # check onnx model
    onnx.checker.check_model(model)
    # save onnx model
    onnx.save(model, args.save_file)
    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))