#! /usr/bin/env python3
'''
Convert a mesh file from one format to another.
'''
from __future__ import print_function

import numpy

import meshio


def _main():
    '''Read a mesh, print a short summary, optionally prune it, and write it
    out again.

    All settings (input/output file names, formats, the prune flag) are taken
    from the command line via `_parse_options()`.
    '''
    # Parse command line arguments.
    args = _parse_options()

    # read mesh data
    points, cells, point_data, cell_data, field_data = \
        meshio.read(args.infile, file_format=args.input_format)

    # Note that the summary describes the mesh as it was read, i.e., before
    # any pruning takes place.
    print('Number of points: {}'.format(len(points)))
    print('Elements:')
    for tpe, elems in cells.items():
        print('  Number of {}s: {}'.format(tpe, len(elems)))

    if point_data:
        print('Point data: {}'.format(', '.join(point_data.keys())))

    # Cell data is stored per cell type; collect the names over all types.
    cell_data_keys = set()
    for cell_type in cell_data:
        cell_data_keys.update(cell_data[cell_type].keys())
    if cell_data_keys:
        print('Cell data: {}'.format(', '.join(cell_data_keys)))

    if args.prune:
        points = _prune_mesh(points, cells, point_data, cell_data)

    # Some converters (like VTK) require `points` to be contiguous.
    points = numpy.ascontiguousarray(points)

    # write it out
    meshio.write(
        args.outfile,
        points,
        cells,
        file_format=args.output_format,
        point_data=point_data,
        cell_data=cell_data,
        field_data=field_data
        )


def _prune_mesh(points, cells, point_data, cell_data):
    '''Remove lower-order cells and orphaned points from a mesh.

    `cells`, `point_data`, and `cell_data` are modified in place; the pruned
    point array is returned.

    Vertex and line cells are always dropped; triangles are dropped only if
    the mesh contains tetrahedra. Every point that is not referenced by any
    of the remaining cells (of whatever type) is removed, and the
    connectivity of all remaining cells is renumbered accordingly.
    '''
    removed_types = ['vertex', 'line']
    if 'tetra' in cells:
        # Triangles only count as lower-order cells next to tetrahedra.
        removed_types.append('triangle')
    for cell_type in removed_types:
        cells.pop(cell_type, None)
        # Data attached to cells that no longer exist has to go as well.
        cell_data.pop(cell_type, None)

    # Mark every point that is mentioned by any of the remaining cells. Looking
    # at all cell types (not only at tetrahedra) keeps meshes without
    # tetrahedra, or with mixed cell types, intact.
    is_used = numpy.zeros(len(points), dtype=bool)
    for elems in cells.values():
        is_used[elems.ravel()] = True
    orphaned_nodes = numpy.flatnonzero(~is_used)

    points = numpy.delete(points, orphaned_nodes, axis=0)
    # also adapt the point data
    for key in point_data:
        point_data[key] = numpy.delete(
            point_data[key],
            orphaned_nodes,
            axis=0
            )

    # reset GLOBAL_ID
    if 'GLOBAL_ID' in point_data:
        point_data['GLOBAL_ID'] = numpy.arange(1, len(points)+1)

    # We now need to adapt the cells too. new_index[i] is the position of the
    # old point i in the pruned point array (only meaningful for used points,
    # which are the only ones the cells refer to).
    new_index = numpy.cumsum(is_used) - 1
    for cell_type in list(cells):
        elems = cells[cell_type]
        # Keep the original integer type of the connectivity array.
        cells[cell_type] = new_index[elems].astype(elems.dtype)

    return points


def _parse_options():
    '''Define the command line interface and return the parsed arguments.

    The returned namespace carries `infile`, `outfile`, `input_format`,
    `output_format` (both None unless given), and the boolean `prune`.
    '''
    import argparse

    cli = argparse.ArgumentParser(description='Convert between mesh formats.')

    cli.add_argument('infile', type=str, help='mesh file to be read from')

    # Explicit file formats; restricted to what meshio reports as supported.
    cli.add_argument(
        '--input-format', '-i',
        type=str, default=None,
        choices=meshio.input_filetypes,
        help='input file format'
        )
    cli.add_argument(
        '--output-format', '-o',
        type=str, default=None,
        choices=meshio.output_filetypes,
        help='output file format'
        )

    cli.add_argument('outfile', type=str, help='mesh file to be written to')

    cli.add_argument(
        '--prune', '-p',
        action='store_true',
        help='remove lower order cells, remove orphaned nodes'
        )

    # `%(prog)s` is left for argparse to expand; only the version number is
    # filled in here.
    version_suffix = '(version %s)' % meshio.__version__
    cli.add_argument(
        '--version', '-v',
        action='version',
        version='%(prog)s ' + version_suffix
        )

    return cli.parse_args()


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    _main()
